Skip to content

Commit b708dff

Browse files
committed
set_arg: Try cl_mem/svm based on what was used last
1 parent ea395d9 commit b708dff

1 file changed

Lines changed: 39 additions & 11 deletions

File tree

src/wrap_cl.hpp

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4641,16 +4641,18 @@ namespace pyopencl
46414641
{
46424642
private:
46434643
cl_kernel m_kernel;
4644+
bool m_set_arg_prefer_svm;
46444645

46454646
public:
46464647
kernel(cl_kernel knl, bool retain)
4647-
: m_kernel(knl)
4648+
: m_kernel(knl), m_set_arg_prefer_svm(false)
46484649
{
46494650
if (retain)
46504651
PYOPENCL_CALL_GUARDED(clRetainKernel, (knl));
46514652
}
46524653

46534654
kernel(program const &prg, std::string const &kernel_name)
4655+
: m_set_arg_prefer_svm(false)
46544656
{
46554657
cl_int status_code;
46564658

@@ -4806,21 +4808,47 @@ namespace pyopencl
48064808
return;
48074809
}
48084810

4809-
try
4811+
// It turns out that a taken 'catch' has a relatively high cost, so
4812+
// in deciding which of "mem object" and "svm" to try first, we use
4813+
// whatever we were given last time around.
4814+
if (m_set_arg_prefer_svm)
48104815
{
4811-
set_arg_mem(arg_index, arg.cast<memory_object_holder &>());
4812-
return;
4816+
#if PYOPENCL_CL_VERSION >= 0x2000
4817+
try
4818+
{
4819+
set_arg_svm(arg_index, arg.cast<svm_pointer const &>());
4820+
return;
4821+
}
4822+
catch (py::cast_error &) { }
4823+
#endif
4824+
4825+
try
4826+
{
4827+
set_arg_mem(arg_index, arg.cast<memory_object_holder &>());
4828+
m_set_arg_prefer_svm = false;
4829+
return;
4830+
}
4831+
catch (py::cast_error &) { }
48134832
}
4814-
catch (py::cast_error &) { }
4833+
else
4834+
{
4835+
try
4836+
{
4837+
set_arg_mem(arg_index, arg.cast<memory_object_holder &>());
4838+
return;
4839+
}
4840+
catch (py::cast_error &) { }
48154841

48164842
#if PYOPENCL_CL_VERSION >= 0x2000
4817-
try
4818-
{
4819-
set_arg_svm(arg_index, arg.cast<svm_pointer const &>());
4820-
return;
4821-
}
4822-
catch (py::cast_error &) { }
4843+
try
4844+
{
4845+
set_arg_svm(arg_index, arg.cast<svm_pointer const &>());
4846+
m_set_arg_prefer_svm = true;
4847+
return;
4848+
}
4849+
catch (py::cast_error &) { }
48234850
#endif
4851+
}
48244852

48254853
try
48264854
{

0 commit comments

Comments
 (0)