Skip to content

Commit 3e9ebd6

Browse files
committed
squash-merge PR #5887 swolchok→enum-perf
1 parent 181d5a7 commit 3e9ebd6

7 files changed

Lines changed: 238 additions & 58 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ set(PYBIND11_HEADERS
188188
include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h
189189
include/pybind11/detail/exception_translation.h
190190
include/pybind11/detail/function_record_pyobject.h
191+
include/pybind11/detail/function_ref.h
191192
include/pybind11/detail/holder_caster_foreign_helpers.h
192193
include/pybind11/detail/init.h
193194
include/pybind11/detail/internals.h

include/pybind11/detail/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@
167167
# define PYBIND11_NOINLINE __attribute__((noinline)) inline
168168
#endif
169169

170+
#if defined(_MSC_VER)
171+
# define PYBIND11_ALWAYS_INLINE __forceinline
172+
#elif defined(__GNUC__)
173+
# define PYBIND11_ALWAYS_INLINE __attribute__((__always_inline__)) inline
174+
#else
175+
# define PYBIND11_ALWAYS_INLINE inline
176+
#endif
177+
170178
#if defined(__MINGW32__)
171179
// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared
172180
// whether it is used or not
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10+
//
11+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://llvm.org/LICENSE.txt for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//===----------------------------------------------------------------------===//
16+
//
17+
// This file contains a header-only class template that provides functionality
18+
// similar to std::function but with non-owning semantics. It is a template-only
19+
// implementation that requires no additional library linking.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
23+
/// An efficient, type-erasing, non-owning reference to a callable. This is
24+
/// intended for use as the type of a function parameter that is not used
25+
/// after the function in question returns.
26+
///
27+
/// This class does not own the callable, so it is not in general safe to store
28+
/// a FunctionRef.
29+
30+
// pybind11: modified again from executorch::runtime::FunctionRef
31+
// - renamed back to function_ref
32+
// - use pybind11 enable_if_t, remove_cvref_t, and remove_reference_t
33+
// - lint suppressions
34+
35+
// torch::executor: modified from llvm::function_ref
36+
// - renamed to FunctionRef
37+
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
38+
// - use namespaced internal::remove_cvref_t
39+
40+
#pragma once
41+
42+
#include <pybind11/detail/common.h>
43+
44+
#include <cstdint>
45+
#include <type_traits>
46+
#include <utility>
47+
48+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
49+
PYBIND11_NAMESPACE_BEGIN(detail)
50+
51+
//===----------------------------------------------------------------------===//
52+
// Features from C++20
53+
//===----------------------------------------------------------------------===//
54+
55+
template <typename Fn>
56+
class function_ref;
57+
58+
template <typename Ret, typename... Params>
59+
class function_ref<Ret(Params...)> {
60+
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
61+
intptr_t callable;
62+
63+
template <typename Callable>
64+
// NOLINTNEXTLINE(performance-unnecessary-value-param)
65+
static Ret callback_fn(intptr_t callable, Params... params) {
66+
// NOLINTNEXTLINE(performance-no-int-to-ptr)
67+
return (*reinterpret_cast<Callable *>(callable))(std::forward<Params>(params)...);
68+
}
69+
70+
public:
71+
function_ref() = default;
72+
// NOLINTNEXTLINE(google-explicit-constructor)
73+
function_ref(std::nullptr_t) {}
74+
75+
template <typename Callable>
76+
// NOLINTNEXTLINE(google-explicit-constructor)
77+
function_ref(
78+
Callable &&callable,
79+
// This is not the copy-constructor.
80+
enable_if_t<!std::is_same<remove_cvref_t<Callable>, function_ref>::value> * = nullptr,
81+
// Functor must be callable and return a suitable type.
82+
enable_if_t<
83+
std::is_void<Ret>::value
84+
|| std::is_convertible<decltype(std::declval<Callable>()(std::declval<Params>()...)),
85+
Ret>::value> * = nullptr)
86+
: callback(callback_fn<remove_reference_t<Callable>>),
87+
callable(reinterpret_cast<intptr_t>(&callable)) {}
88+
89+
// NOLINTNEXTLINE(performance-unnecessary-value-param)
90+
Ret operator()(Params... params) const {
91+
return callback(callable, std::forward<Params>(params)...);
92+
}
93+
94+
explicit operator bool() const { return callback; }
95+
96+
bool operator==(const function_ref<Ret(Params...)> &Other) const {
97+
return callable == Other.callable;
98+
}
99+
};
100+
PYBIND11_NAMESPACE_END(detail)
101+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

include/pybind11/pybind11.h

Lines changed: 110 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "detail/dynamic_raw_ptr_cast_if_possible.h"
1414
#include "detail/exception_translation.h"
1515
#include "detail/function_record_pyobject.h"
16+
#include "detail/function_ref.h"
1617
#include "detail/init.h"
1718
#include "detail/native_enum_data.h"
1819
#include "detail/using_smart_holder.h"
@@ -386,6 +387,46 @@ class cpp_function : public function {
386387
return unique_function_record(new detail::function_record());
387388
}
388389

390+
private:
391+
// This is outlined from the dispatch lambda in initialize to save
392+
// on code size. Crucially, we use function_ref to type-erase the
393+
// actual function lambda so that we can get code reuse for
394+
// functions with the same Return, Args, and Guard.
395+
template <typename Return, typename Guard, typename ArgsConverter, typename... Args>
396+
static handle call_impl(detail::function_call &call, detail::function_ref<Return(Args...)> f) {
397+
using namespace detail;
398+
// Static assertion: function_ref must be trivially copyable to ensure safe pass-by-value.
399+
// Lifetime safety: The function_ref is created from cap->f which lives in the capture
400+
// object stored in the function record, and is only used synchronously within this
401+
// function call. It is never stored beyond the scope of call_impl.
402+
static_assert(std::is_trivially_copyable<detail::function_ref<Return(Args...)>>::value,
403+
"function_ref must be trivially copyable for safe pass-by-value usage");
404+
using cast_out
405+
= make_caster<conditional_t<std::is_void<Return>::value, void_type, Return>>;
406+
407+
ArgsConverter args_converter;
408+
if (!args_converter.load_args(call)) {
409+
return PYBIND11_TRY_NEXT_OVERLOAD;
410+
}
411+
412+
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
413+
return_value_policy policy
414+
= return_value_policy_override<Return>::policy(call.func.policy);
415+
416+
/* Perform the function call */
417+
handle result;
418+
if (call.func.is_setter) {
419+
(void) std::move(args_converter).template call<Return, Guard>(f);
420+
result = none().release();
421+
} else {
422+
result = cast_out::cast(
423+
std::move(args_converter).template call<Return, Guard>(f), policy, call.parent);
424+
}
425+
426+
return result;
427+
}
428+
429+
protected:
389430
/// Special internal constructor for functors, lambda functions, etc.
390431
template <typename Func, typename Return, typename... Args, typename... Extra>
391432
void initialize(Func &&f, Return (*)(Args...), const Extra &...extra) {
@@ -448,13 +489,6 @@ class cpp_function : public function {
448489

449490
/* Dispatch code which converts function arguments and performs the actual function call */
450491
rec->impl = [](function_call &call) -> handle {
451-
cast_in args_converter;
452-
453-
/* Try to cast the function arguments into the C++ domain */
454-
if (!args_converter.load_args(call)) {
455-
return PYBIND11_TRY_NEXT_OVERLOAD;
456-
}
457-
458492
/* Invoke call policy pre-call hook */
459493
process_attributes<Extra...>::precall(call);
460494

@@ -463,24 +497,11 @@ class cpp_function : public function {
463497
: call.func.data[0]);
464498
auto *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));
465499

466-
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
467-
return_value_policy policy
468-
= return_value_policy_override<Return>::policy(call.func.policy);
469-
470-
/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
471-
using Guard = extract_guard_t<Extra...>;
472-
473-
/* Perform the function call */
474-
handle result;
475-
if (call.func.is_setter) {
476-
(void) std::move(args_converter).template call<Return, Guard>(cap->f);
477-
result = none().release();
478-
} else {
479-
result = cast_out::cast(
480-
std::move(args_converter).template call<Return, Guard>(cap->f),
481-
policy,
482-
call.parent);
483-
}
500+
auto result = call_impl<Return,
501+
/* Function scope guard -- defaults to the compile-to-nothing
502+
`void_type` */
503+
extract_guard_t<Extra...>,
504+
cast_in>(call, detail::function_ref<Return(Args...)>(cap->f));
484505

485506
/* Invoke call policy post-call hook */
486507
process_attributes<Extra...>::postcall(call, result);
@@ -2245,7 +2266,7 @@ class class_ : public detail::generic_type {
22452266
static void add_base(detail::type_record &) {}
22462267

22472268
template <typename Func, typename... Extra>
2248-
class_ &def(const char *name_, Func &&f, const Extra &...extra) {
2269+
PYBIND11_ALWAYS_INLINE class_ &def(const char *name_, Func &&f, const Extra &...extra) {
22492270
cpp_function cf(method_adaptor<type>(std::forward<Func>(f)),
22502271
name(name_),
22512272
is_method(*this),
@@ -2830,38 +2851,13 @@ struct enum_base {
28302851
pos_only())
28312852

28322853
if (is_convertible) {
2833-
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
2834-
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));
2835-
28362854
if (is_arithmetic) {
2837-
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
2838-
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
2839-
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
2840-
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
2841-
PYBIND11_ENUM_OP_CONV("__and__", a & b);
2842-
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
2843-
PYBIND11_ENUM_OP_CONV("__or__", a | b);
2844-
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
2845-
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
2846-
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
28472855
m_base.attr("__invert__")
28482856
= cpp_function([](const object &arg) { return ~(int_(arg)); },
28492857
name("__invert__"),
28502858
is_method(m_base),
28512859
pos_only());
28522860
}
2853-
} else {
2854-
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
2855-
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
2856-
2857-
if (is_arithmetic) {
2858-
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
2859-
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
2860-
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
2861-
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
2862-
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
2863-
#undef PYBIND11_THROW
2864-
}
28652861
}
28662862

28672863
#undef PYBIND11_ENUM_OP_CONV_LHS
@@ -2977,6 +2973,69 @@ class enum_ : public class_<Type> {
29772973

29782974
def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
29792975
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
2976+
#define PYBIND11_ENUM_OP_SAME_TYPE(op, expr) \
2977+
def(op, [](Type a, Type b) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2978+
#define PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE(op, expr) \
2979+
def(op, [](Type a, Type *b_ptr) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2980+
#define PYBIND11_ENUM_OP_SCALAR(op, op_expr) \
2981+
def( \
2982+
op, \
2983+
[](Type a, Scalar b) { return static_cast<Scalar>(a) op_expr b; }, \
2984+
pybind11::name(op), \
2985+
arg("other"), \
2986+
pos_only())
2987+
#define PYBIND11_ENUM_OP_CONV_ARITHMETIC(op, op_expr) \
2988+
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
2989+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
2990+
PYBIND11_ENUM_OP_SCALAR(op, op_expr)
2991+
#define PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior) \
2992+
def( \
2993+
op, \
2994+
[](Type, const object &) { strict_behavior; }, \
2995+
pybind11::name(op), \
2996+
arg("other"), \
2997+
pos_only())
2998+
#define PYBIND11_ENUM_OP_STRICT_ARITHMETIC(op, op_expr, strict_behavior) \
2999+
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
3000+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
3001+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior);
3002+
3003+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__eq__", b_ptr && a == *b_ptr);
3004+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__ne__", !b_ptr || a != *b_ptr);
3005+
if (std::is_convertible<Type, Scalar>::value) {
3006+
PYBIND11_ENUM_OP_SCALAR("__eq__", ==);
3007+
PYBIND11_ENUM_OP_SCALAR("__ne__", !=);
3008+
if (is_arithmetic) {
3009+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__lt__", <);
3010+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__gt__", >);
3011+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__le__", <=);
3012+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ge__", >=);
3013+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__and__", &);
3014+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rand__", &);
3015+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__or__", |);
3016+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ror__", |);
3017+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__xor__", ^);
3018+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rxor__", ^);
3019+
}
3020+
} else if (is_arithmetic) {
3021+
#define PYBIND11_ENUM_OP_THROW_TYPE_ERROR \
3022+
throw type_error("Expected an enumeration of matching type!");
3023+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__lt__", <, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3024+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__gt__", >, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3025+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__le__", <=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3026+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__ge__", >=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3027+
#undef PYBIND11_ENUM_OP_THROW_TYPE_ERROR
3028+
}
3029+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__eq__", return false);
3030+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__ne__", return true);
3031+
3032+
#undef PYBIND11_ENUM_OP_SAME_TYPE
3033+
#undef PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE
3034+
#undef PYBIND11_ENUM_OP_SCALAR
3035+
#undef PYBIND11_ENUM_OP_CONV_ARITHMETIC
3036+
#undef PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE
3037+
#undef PYBIND11_ENUM_OP_STRICT_ARITHMETIC
3038+
29803039
def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
29813040
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
29823041
attr("__setstate__") = cpp_function(

tests/extra_python_package/test_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"include/pybind11/detail/descr.h",
8484
"include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h",
8585
"include/pybind11/detail/function_record_pyobject.h",
86+
"include/pybind11/detail/function_ref.h",
8687
"include/pybind11/detail/holder_caster_foreign_helpers.h",
8788
"include/pybind11/detail/init.h",
8889
"include/pybind11/detail/internals.h",

tests/test_copy_move.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def test_move_and_copy_loads():
7070

7171
assert c_m.copy_assignments + c_m.copy_constructions == 0
7272
assert c_m.move_assignments == 6
73-
assert c_m.move_constructions == 9
73+
assert c_m.move_constructions == 21
7474
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
7575
assert c_mc.move_assignments == 5
76-
assert c_mc.move_constructions == 8
76+
assert c_mc.move_constructions == 18
7777
assert c_c.copy_assignments == 4
78-
assert c_c.copy_constructions == 6
78+
assert c_c.copy_constructions == 14
7979
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0
8080

8181

@@ -103,12 +103,12 @@ def test_move_and_copy_load_optional():
103103

104104
assert c_m.copy_assignments + c_m.copy_constructions == 0
105105
assert c_m.move_assignments == 2
106-
assert c_m.move_constructions == 5
106+
assert c_m.move_constructions == 9
107107
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
108108
assert c_mc.move_assignments == 2
109-
assert c_mc.move_constructions == 5
109+
assert c_mc.move_constructions == 9
110110
assert c_c.copy_assignments == 2
111-
assert c_c.copy_constructions == 5
111+
assert c_c.copy_constructions == 9
112112
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0
113113

114114

0 commit comments

Comments
 (0)