Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,4 @@
# Make the disco module optional.
disco = None # type: ignore[assignment]

from .support import _regex_match
from tvm_ffi import Shape as ShapeTuple
51 changes: 0 additions & 51 deletions python/tvm/runtime/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,8 @@

"""Runtime support infra of TVM."""

import re
from typing import TypeVar

import tvm_ffi


@tvm_ffi.register_global_func("tvm.runtime.regex_match")
def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""Check if a pattern matches a regular expression

This function should be used instead of `std::regex` within C++
call sites, to avoid ABI incompatibilities with pytorch.

Currently, the pytorch wheels available through pip install use
the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
user the pre-C++11 ABI, this would cause breakages with
dynamically-linked LLVM environments.

Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex`
is called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can
be removed when a version of pytorch compiled using
`-DUSE_CXX11_ABI=1` is available from PyPI.

This is exposed as part of `libtvm_runtime.so` as it is used by
the DNNL runtime.

[0] https://github.com/pytorch/pytorch/issues/51039

Parameters
----------
regex_pattern: str

The regular expression

match_against: str

The string against which to match the regular expression

Returns
-------
match_result: bool

True if `match_against` matches the pattern defined by
`regex_pattern`, and False otherwise.

"""
match = re.match(regex_pattern, match_against)
return match is not None


T = TypeVar("T")


Expand Down
1 change: 0 additions & 1 deletion src/relax/transform/update_param_struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include <unordered_map>
#include <vector>

#include "../../runtime/regex.h"
#include "utils.h"

namespace tvm {
Expand Down
60 changes: 25 additions & 35 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <string>
#include <vector>

#include "../../../runtime/regex.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"

Expand All @@ -49,6 +48,16 @@ namespace contrib {
using namespace tvm::runtime;
using namespace tvm::runtime::json;

namespace {
inline bool contains(const std::string& s, const std::string& sub) {
return s.find(sub) != std::string::npos;
}
template <typename... Args>
inline bool contains_any(const std::string& s, const Args&... args) {
return (contains(s, args) || ...);
}
} // namespace

class DNNLJSONRuntime : public JSONRuntimeBase {
public:
DNNLJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
Expand Down Expand Up @@ -189,55 +198,43 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

// Define RegExp.
std::string bias_add_pat(".*_bias.*");
std::string relu_pat(".*_relu.*");
std::string tanh_pat(".*_tanh.*");
std::string sigmoid_pat(".*_sigmoid.*");
std::string clip_pat(".*_clip.*");
std::string gelu_pat(".*_gelu.*");
std::string swish_pat(".*_swish.*");
std::string sum_pat(".*_sum.*");
std::string mish_pat(".*_mish.*");

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();

// Parsing post-ops.
dnnl::post_ops ops;
if (tvm::runtime::regex_match(op_name, sum_pat)) {
if (contains(op_name, "_sum")) {
ops.append_sum(1.f);
}
if (tvm::runtime::regex_match(op_name, relu_pat)) {
if (contains(op_name, "_relu")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
if (tvm::runtime::regex_match(op_name, tanh_pat)) {
if (contains(op_name, "_tanh")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
if (tvm::runtime::regex_match(op_name, clip_pat)) {
if (contains(op_name, "_clip")) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (tvm::runtime::regex_match(op_name, sigmoid_pat)) {
if (contains(op_name, "_sigmoid")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
if (tvm::runtime::regex_match(op_name, swish_pat)) {
if (contains(op_name, "_swish")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
}
if (tvm::runtime::regex_match(op_name, gelu_pat)) {
if (contains(op_name, "_gelu")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
if (tvm::runtime::regex_match(op_name, mish_pat)) {
if (contains(op_name, "_mish")) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
}
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

// Parsing bias_add.
*bias_tr =
tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
*bias_tr = contains(op_name, "_bias") ? GetInput(nid, 2) : TensorRequisite{};

return attr;
}
Expand All @@ -250,31 +247,24 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
tensor_registry_ = TensorRegistry(engine_, io_eid_set);

std::string conv_pat(".*conv[1-3]d.*");
std::string deconv_pat(".*deconv[1-3]d.*");
std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
std::string dense_pat(".*dense.*");
std::string max_pool_pat(".*max_pool[1-3]d");
std::string avg_pool_pat(".*avg_pool[1-3]d");

// Build subgraph engine.
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
TVM_FFI_ICHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if (tvm::runtime::regex_match(op_name, deconv_pat) ||
tvm::runtime::regex_match(op_name, conv_transpose_pat)) {
if (contains_any(op_name, "deconv1d", "deconv2d", "deconv3d", "conv1d_transpose",
"conv2d_transpose", "conv3d_transpose")) {
Deconvolution(nid);
} else if (tvm::runtime::regex_match(op_name, conv_pat)) {
} else if (contains_any(op_name, "conv1d", "conv2d", "conv3d")) {
Convolution(nid);
} else if (tvm::runtime::regex_match(op_name, dense_pat)) {
} else if (contains(op_name, "dense")) {
Dense(nid);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if (tvm::runtime::regex_match(op_name, max_pool_pat)) {
} else if (contains_any(op_name, "max_pool1d", "max_pool2d", "max_pool3d")) {
Pooling(nid, dnnl::algorithm::pooling_max);
} else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) {
} else if (contains_any(op_name, "avg_pool1d", "avg_pool2d", "avg_pool3d")) {
Pooling(nid, dnnl::algorithm::pooling_avg);
} else if (elt_name2algo.count(op_name)) {
Eltwise(nid);
Expand Down
43 changes: 0 additions & 43 deletions src/runtime/regex.cc

This file was deleted.

67 changes: 0 additions & 67 deletions src/runtime/regex.h

This file was deleted.

Loading