From 286d16bec0e7c8693d1776b23acf648671a12db4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 23 Jun 2026 15:43:02 +0000 Subject: [PATCH 1/8] [REFACTOR][IR] Use PrimType for compiler dtypes Use PrimType as the compiler-facing dtype/type carrier so primitive expression dtype information is unified with Expr.ty instead of flowing through a dedicated dtype path. This keeps compiler IR type information in the type system and leaves room for future expression type annotations. Use raw DLDataType at runtime, ABI, storage-helper, and dtype-valued attr boundaries where a plain DLPack dtype value is the real interface. Keep the PrimType API minimal and hot-path friendly with value equality, matching helpers, documented factories, and cached common constructors. Update TIRX, TE, TOPI, Relax, codegen, Python bindings, and tests to follow the compiler PrimType versus runtime DLDataType boundary. --- include/tvm/ir/base_expr.h | 311 +++++++ include/tvm/ir/expr.h | 135 +-- include/tvm/ir/type.h | 65 +- include/tvm/relax/attrs/create.h | 2 +- include/tvm/relax/attrs/datatype.h | 4 +- include/tvm/relax/attrs/image.h | 4 +- include/tvm/relax/attrs/linear_algebra.h | 2 +- include/tvm/relax/attrs/nn.h | 12 +- include/tvm/relax/attrs/qdq.h | 2 +- include/tvm/relax/attrs/sampling.h | 4 +- include/tvm/relax/attrs/sorting.h | 8 +- include/tvm/relax/attrs/statistical.h | 2 +- include/tvm/relax/dataflow_pattern.h | 8 +- include/tvm/relax/distributed/global_info.h | 1 + include/tvm/relax/expr.h | 4 +- include/tvm/relax/transform.h | 5 +- include/tvm/relax/type.h | 14 +- include/tvm/runtime/data_type.h | 522 ------------ include/tvm/runtime/disco/builtin.h | 4 +- include/tvm/runtime/tensor.h | 4 +- include/tvm/runtime/vm/bytecode.h | 2 +- include/tvm/runtime/vm/tensor_cache_support.h | 2 +- include/tvm/s_tir/data_layout.h | 4 +- include/tvm/s_tir/meta_schedule/arg_info.h | 6 +- include/tvm/script/printer/config.h | 9 +- include/tvm/script/printer/doc.h | 10 +- 
include/tvm/script/printer/ir_docsifier.h | 2 +- include/tvm/te/operation.h | 32 +- include/tvm/te/tensor.h | 10 +- include/tvm/tirx/buffer.h | 73 +- include/tvm/tirx/expr.h | 8 +- include/tvm/tirx/op.h | 166 ++-- include/tvm/tirx/script/builder/ir.h | 147 ++-- include/tvm/tirx/stmt.h | 2 +- include/tvm/tirx/var.h | 10 +- include/tvm/topi/broadcast.h | 12 +- include/tvm/topi/contrib/cublas.h | 4 +- include/tvm/topi/detail/broadcast.h | 18 +- include/tvm/topi/detail/extern.h | 13 +- include/tvm/topi/detail/strided_slice.h | 6 +- include/tvm/topi/detail/tensor_utils.h | 8 +- include/tvm/topi/elemwise.h | 111 ++- include/tvm/topi/nn.h | 40 +- include/tvm/topi/nn/bnn.h | 8 +- include/tvm/topi/nn/dense.h | 2 +- include/tvm/topi/nn/dilate.h | 2 +- include/tvm/topi/nn/group_norm.h | 14 +- include/tvm/topi/nn/instance_norm.h | 22 +- include/tvm/topi/nn/layer_norm.h | 21 +- include/tvm/topi/nn/local_response_norm.h | 9 +- include/tvm/topi/nn/pooling.h | 31 +- include/tvm/topi/nn/rms_norm.h | 10 +- include/tvm/topi/reduction.h | 15 +- include/tvm/topi/transform.h | 112 ++- python/tvm/ir/expr.py | 5 +- python/tvm/ir/type.py | 29 + python/tvm/relax/frontend/nn/extern.py | 2 +- .../torch/base_fx_graph_translator.py | 5 +- python/tvm/relax/op/create.py | 8 +- python/tvm/relax/op/manipulate.py | 11 +- .../relax/transform/legalize_ops/common.py | 13 +- .../transform/legalize_ops/manipulate.py | 4 +- .../tvm/relax/transform/legalize_ops/qdq.py | 7 +- python/tvm/relax/type.py | 8 +- python/tvm/runtime/object_generic.py | 6 +- python/tvm/s_tir/schedule/schedule.py | 14 +- python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/te/tensor.py | 14 + python/tvm/tirx/buffer.py | 2 +- python/tvm/tirx/expr.py | 31 +- python/tvm/tirx/script/parser/operation.py | 69 +- python/tvm/topi/math.py | 54 +- python/tvm/topi/scatter.py | 4 +- python/tvm/topi/sort.py | 2 +- src/arith/analyzer.cc | 11 +- src/arith/bound_deducer.cc | 3 +- src/arith/canonical_simplify.cc | 106 +-- 
src/arith/const_fold.h | 134 +-- src/arith/const_int_bound.cc | 41 +- src/arith/detect_linear_equation.cc | 24 +- src/arith/int_constraints.cc | 10 +- src/arith/int_set.cc | 46 +- src/arith/ir_mutator_with_analyzer.cc | 10 +- src/arith/ir_visitor_with_analyzer.cc | 2 +- src/arith/iter_affine_map.cc | 72 +- src/arith/pattern_match.h | 38 +- src/arith/product_normal_form.h | 5 +- src/arith/rewrite_simplify.cc | 114 +-- src/arith/solve_linear_equation.cc | 28 +- src/arith/solve_linear_inequality.cc | 25 +- src/arith/transitive_comparison_analyzer.cc | 3 +- src/arith/unwrap_vector_expr.cc | 6 +- src/arith/z3_prover.cc | 44 +- src/backend/cuda/codegen/codegen_cuda.cc | 623 +++++++------- src/backend/cuda/codegen/codegen_cuda.h | 19 +- src/backend/cuda/codegen/intrin_rule_cuda.cc | 26 +- .../cuda/codegen/llvm/codegen_nvptx.cc | 12 +- .../cuda/codegen/llvm/intrin_rule_nvptx.cc | 9 +- src/backend/cuda/runtime/cuda_device_api.cc | 16 +- .../hexagon/codegen/llvm/codegen_hexagon.cc | 42 +- .../codegen/llvm/intrin_rule_hexagon.cc | 40 +- .../hexagon/runtime/ops/conv2d_fp16_hvx.cc | 4 +- src/backend/metal/codegen/codegen_metal.cc | 73 +- src/backend/metal/codegen/codegen_metal.h | 7 +- .../metal/codegen/intrin_rule_metal.cc | 6 +- src/backend/opencl/codegen/codegen_opencl.cc | 125 +-- src/backend/opencl/codegen/codegen_opencl.h | 20 +- .../opencl/codegen/intrin_rule_opencl.cc | 4 +- src/backend/opencl/runtime/opencl_common.h | 21 +- .../opencl/runtime/opencl_device_api.cc | 6 +- src/backend/opencl/runtime/texture.h | 10 +- .../rocm/codegen/llvm/codegen_amdgpu.cc | 9 +- .../rocm/codegen/llvm/intrin_rule_rocm.cc | 22 +- src/backend/trn/codegen/codegen_trn.cc | 25 +- src/backend/trn/codegen/codegen_trn.h | 4 +- .../trn/transform/lower_trainium_layout.cc | 20 +- src/backend/vulkan/codegen/codegen_spirv.cc | 112 +-- src/backend/vulkan/codegen/codegen_spirv.h | 14 +- .../vulkan/codegen/intrin_rule_spirv.cc | 13 +- src/backend/vulkan/codegen/ir_builder.cc | 221 ++--- 
src/backend/vulkan/codegen/ir_builder.h | 10 +- src/backend/webgpu/codegen/codegen_webgpu.cc | 154 ++-- src/backend/webgpu/codegen/codegen_webgpu.h | 15 +- .../webgpu/codegen/intrin_rule_webgpu.cc | 10 +- src/ir/expr.cc | 150 ++-- src/ir/type.cc | 105 ++- src/relax/analysis/tir_op_pattern_kind.cc | 7 +- src/relax/analysis/type_analysis.cc | 18 +- src/relax/analysis/well_formed.cc | 4 +- .../backend/contrib/codegen_c/codegen_c.h | 15 +- src/relax/backend/contrib/cublas/codegen.cc | 4 +- src/relax/backend/contrib/utils.h | 4 +- src/relax/backend/vm/codegen_vm_tir.cc | 32 +- src/relax/backend/vm/lower_runtime_builtin.cc | 4 +- src/relax/backend/vm/vm_shape_lower.cc | 19 +- src/relax/ir/dataflow_expr_rewriter.cc | 2 +- src/relax/ir/dataflow_matcher.cc | 3 +- src/relax/ir/dataflow_pattern.cc | 10 +- src/relax/ir/dependent_type.cc | 15 +- src/relax/ir/emit_te.cc | 2 +- src/relax/ir/expr.cc | 14 +- src/relax/op/ccl/ccl.cc | 4 +- src/relax/op/distributed/binary.cc | 2 +- src/relax/op/distributed/binary.h | 4 +- src/relax/op/distributed/distributed.cc | 2 +- src/relax/op/distributed/linear_algebra.cc | 6 +- src/relax/op/distributed/nn.cc | 4 +- src/relax/op/distributed/unary.cc | 2 +- src/relax/op/distributed/unary.h | 11 +- src/relax/op/image/resize.cc | 20 +- src/relax/op/image/resize.h | 4 +- src/relax/op/memory/view.cc | 28 +- src/relax/op/nn/attention.cc | 2 +- src/relax/op/nn/convolution.cc | 82 +- src/relax/op/nn/convolution.h | 16 +- src/relax/op/nn/nn.cc | 69 +- src/relax/op/nn/pooling.cc | 6 +- src/relax/op/op.cc | 20 +- src/relax/op/op_common.h | 40 +- src/relax/op/tensor/binary.cc | 8 +- src/relax/op/tensor/create.cc | 73 +- src/relax/op/tensor/create.h | 20 +- src/relax/op/tensor/datatype.cc | 8 +- src/relax/op/tensor/datatype.h | 4 +- src/relax/op/tensor/index.cc | 27 +- src/relax/op/tensor/inspect.cc | 122 +-- src/relax/op/tensor/inspect.h | 16 +- src/relax/op/tensor/linear_algebra.cc | 26 +- src/relax/op/tensor/linear_algebra.h | 2 +- 
src/relax/op/tensor/manipulate.cc | 125 +-- src/relax/op/tensor/qdq.cc | 63 +- src/relax/op/tensor/qdq.h | 4 +- src/relax/op/tensor/sampling.cc | 20 +- src/relax/op/tensor/sampling.h | 3 +- src/relax/op/tensor/search.cc | 15 +- src/relax/op/tensor/set.cc | 14 +- src/relax/op/tensor/sorting.cc | 10 +- src/relax/op/tensor/sorting.h | 4 +- src/relax/op/tensor/statistical.cc | 16 +- src/relax/op/tensor/statistical.h | 4 +- src/relax/op/tensor/ternary.cc | 4 +- src/relax/op/tensor/unary.cc | 2 +- src/relax/op/vision/nms.cc | 38 +- src/relax/script/printer/dependent_type.cc | 2 +- src/relax/script/printer/distributed.cc | 4 +- src/relax/script/printer/expr.cc | 18 +- src/relax/script/printer/tir.cc | 7 +- src/relax/transform/adjust_matmul_order.cc | 20 +- src/relax/transform/allocate_workspace.cc | 4 +- src/relax/transform/alter_op_impl.cc | 10 +- src/relax/transform/call_tir_rewrite.cc | 16 +- .../transform/combine_parallel_matmul.cc | 2 +- src/relax/transform/compute_prim_value.cc | 7 +- src/relax/transform/convert_layout.cc | 2 +- src/relax/transform/dataflow_inplace.cc | 2 +- src/relax/transform/decompose_ops.cc | 10 +- src/relax/transform/expand_matmul_of_sum.cc | 3 +- src/relax/transform/fold_constant.cc | 10 +- src/relax/transform/fuse_tir.cc | 38 +- src/relax/transform/gradient.cc | 7 +- src/relax/transform/infer_amp_utils.cc | 22 +- src/relax/transform/infer_amp_utils.h | 11 +- src/relax/transform/lazy_transform_params.cc | 5 +- src/relax/transform/legalize_ops.cc | 2 +- src/relax/transform/lower_alloc_tensor.cc | 5 +- src/relax/transform/remove_unused_outputs.cc | 2 +- .../transform/remove_unused_parameters.cc | 2 +- .../transform/reorder_take_after_matmul.cc | 4 +- .../transform/split_call_tir_by_pattern.cc | 14 +- .../transform/split_layout_rewrite_preproc.cc | 4 +- .../transform/static_plan_block_memory.cc | 50 +- src/relax/transform/to_mixed_precision.cc | 53 +- src/relax/transform/utils.h | 32 +- src/relax/utils.cc | 8 +- 
src/runtime/extra/contrib/cblas/cblas.cc | 2 +- src/runtime/extra/contrib/cblas/dnnl_blas.cc | 2 +- src/runtime/extra/contrib/cblas/gemm_common.h | 17 +- src/runtime/extra/contrib/cblas/mkl.cc | 2 +- .../extra/contrib/coreml/coreml_runtime.mm | 16 +- src/runtime/extra/contrib/cublas/cublas.cc | 6 +- .../extra/contrib/cudnn/conv_backward.cc | 2 +- .../extra/contrib/cudnn/conv_forward.cc | 2 +- .../extra/contrib/cudnn/cudnn_utils.cc | 2 +- .../extra/contrib/cutlass/fp16_group_gemm.cuh | 12 +- .../cutlass/fp8_groupwise_scaled_gemm.cuh | 50 +- .../fp8_groupwise_scaled_group_gemm_sm100.cu | 17 +- src/runtime/extra/contrib/dnnl/dnnl_utils.cc | 8 +- src/runtime/extra/contrib/dnnl/dnnl_utils.h | 2 +- src/runtime/extra/contrib/hipblas/hipblas.cc | 2 +- src/runtime/extra/contrib/json/json_node.h | 2 +- .../extra/contrib/nvshmem/memory_allocator.cc | 4 +- src/runtime/extra/contrib/random/random.cc | 2 +- src/runtime/extra/contrib/sort/sort.cc | 4 +- src/runtime/extra/contrib/vllm/cache_alloc.cc | 4 +- .../extra/contrib/vllm/cache_kernels.cu | 6 +- src/runtime/extra/disco/builtin.cc | 4 +- .../extra/disco/cuda_ipc/cuda_ipc_memory.cc | 10 +- .../extra/disco/cuda_ipc/custom_allreduce.cc | 2 +- src/runtime/extra/disco/loader.cc | 9 +- src/runtime/extra/disco/nccl/nccl.cc | 20 +- src/runtime/extra/disco/nccl/nccl_context.h | 26 +- src/runtime/tensor.cc | 8 +- src/runtime/vm/attn_backend.h | 4 +- src/runtime/vm/attn_utils.h | 8 +- src/runtime/vm/builtin.cc | 31 +- src/runtime/vm/executable.cc | 3 +- src/runtime/vm/lm_support.cc | 23 +- src/runtime/vm/paged_kv_cache.cc | 22 +- src/runtime/vm/rnn_state.cc | 2 +- src/runtime/vm/tensor_cache_support.cc | 4 +- .../analysis/calculate_allocated_memory.cc | 2 +- src/s_tir/analysis/estimate_flops.cc | 10 +- .../analysis/sblock_access_region_detector.cc | 2 +- src/s_tir/analysis/verify_gpu_code.cc | 77 +- .../backend/adreno/inject_texture_alloc.cc | 4 +- src/s_tir/backend/adreno/texture_flatten.cc | 4 +- src/s_tir/data_layout.cc | 42 +- 
src/s_tir/meta_schedule/arg_info.cc | 13 +- .../meta_schedule/database/database_utils.cc | 5 +- .../feature_extractor/per_store_feature.cc | 24 +- .../measure_callback/add_to_database.cc | 2 +- src/s_tir/meta_schedule/mutator/mutator.cc | 22 +- .../postproc/rewrite_cooperative_fetch.cc | 6 +- src/s_tir/meta_schedule/profiler.cc | 2 +- .../schedule/cuda/thread_bind.cc | 2 +- .../schedule_rule/cross_thread_reduction.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 10 +- .../multi_level_tiling_tensor_core.cc | 8 +- .../parallel_vectorize_unroll.cc | 2 +- .../schedule_rule/schedule_rule.cc | 2 +- src/s_tir/meta_schedule/utils.h | 2 +- src/s_tir/schedule/analysis/layout.cc | 6 +- src/s_tir/schedule/analysis/reducer.cc | 2 +- src/s_tir/schedule/concrete_schedule.cc | 8 +- src/s_tir/schedule/concrete_schedule.h | 2 +- src/s_tir/schedule/ir_comparator.cc | 18 +- .../schedule/primitive/block_annotate.cc | 18 +- .../schedule/primitive/blockize_tensorize.cc | 12 +- src/s_tir/schedule/primitive/cache_index.cc | 24 +- .../schedule/primitive/cache_read_write.cc | 30 +- src/s_tir/schedule/primitive/compute_at.cc | 8 +- .../schedule/primitive/compute_inline.cc | 4 +- .../schedule/primitive/decompose_padding.cc | 4 +- src/s_tir/schedule/primitive/for_kind.cc | 6 +- .../primitive/layout_transformation.cc | 42 +- .../schedule/primitive/loop_transformation.cc | 34 +- src/s_tir/schedule/primitive/pad_einsum.cc | 16 +- src/s_tir/schedule/primitive/reduction.cc | 38 +- src/s_tir/schedule/transform.cc | 4 +- src/s_tir/schedule/transform.h | 2 +- src/s_tir/schedule/utils.h | 4 +- src/s_tir/transform/bound_checker.cc | 16 +- src/s_tir/transform/canonicalize_loop.cc | 4 +- src/s_tir/transform/compact_buffer_region.cc | 23 +- src/s_tir/transform/default_gpu_schedule.cc | 8 +- src/s_tir/transform/inject_double_buffer.cc | 18 +- src/s_tir/transform/inject_permuted_layout.cc | 2 +- src/s_tir/transform/inject_ptx_async_copy.cc | 10 +- src/s_tir/transform/inject_ptx_ldg32.cc | 5 +- 
.../transform/inject_software_pipeline.cc | 8 +- src/s_tir/transform/inject_virtual_thread.cc | 15 +- src/s_tir/transform/lift_thread_binding.cc | 2 +- src/s_tir/transform/loop_partition.cc | 8 +- src/s_tir/transform/lower_async_dma.cc | 20 +- .../transform/lower_cross_thread_reduction.cc | 12 +- src/s_tir/transform/lower_match_buffer.cc | 21 +- src/s_tir/transform/lower_opaque_block.cc | 4 +- src/s_tir/transform/lower_thread_allreduce.cc | 78 +- src/s_tir/transform/lower_vtcm_alloc.cc | 4 +- .../transform/memhammer_tensorcore_rewrite.cc | 53 +- .../merge_shared_memory_allocations.cc | 41 +- .../transform/profile_instrumentation.cc | 8 +- src/s_tir/transform/renew_defs.cc | 4 +- .../transform/renormalize_split_pattern.cc | 26 +- src/s_tir/transform/rewrite_unsafe_select.cc | 5 +- src/s_tir/transform/storage_access.cc | 8 +- src/s_tir/transform/storage_access.h | 2 +- src/s_tir/transform/thread_storage_sync.cc | 4 +- src/s_tir/transform/unify_thread_binding.cc | 18 +- .../printer/doc_printer/python_doc_printer.cc | 3 +- src/script/printer/ir/distributed.cc | 1 + src/script/printer/script_printer.cc | 6 +- src/script/printer/utils.h | 6 +- src/target/build_common.h | 2 +- src/target/intrin_rule.cc | 63 +- src/target/intrin_rule.h | 16 +- src/target/llvm/codegen_arm.cc | 34 +- src/target/llvm/codegen_cpu.cc | 47 +- src/target/llvm/codegen_cpu.h | 4 +- src/target/llvm/codegen_llvm.cc | 317 +++---- src/target/llvm/codegen_llvm.h | 32 +- src/target/llvm/codegen_params.cc | 46 +- src/target/llvm/codegen_x86_64.cc | 15 +- src/target/llvm/intrin_rule_llvm.cc | 15 +- src/target/llvm/intrin_rule_llvm.h | 8 +- src/target/source/codegen_c.cc | 223 ++--- src/target/source/codegen_c.h | 30 +- src/target/source/codegen_c_host.cc | 28 +- src/target/source/codegen_c_host.h | 4 +- src/target/source/codegen_params.cc | 57 +- src/target/source/codegen_source_base.cc | 31 +- src/target/source/codegen_source_base.h | 9 +- src/target/source/source_module.cc | 2 +- 
src/te/operation/compute_op.cc | 12 +- src/te/operation/create_primfunc.cc | 24 +- src/te/operation/create_primfunc.h | 4 +- src/te/operation/extern_op.cc | 4 +- src/te/operation/placeholder_op.cc | 10 +- src/te/operation/scan_op.cc | 2 +- src/te/tensor.cc | 17 +- src/tirx/analysis/deep_equal.cc | 50 +- src/tirx/ir/buffer.cc | 109 ++- src/tirx/ir/buffer_common.h | 6 +- src/tirx/ir/data_type_rewriter.cc | 152 ++-- src/tirx/ir/data_type_rewriter.h | 6 +- src/tirx/ir/exec_scope.cc | 6 +- src/tirx/ir/expr.cc | 279 ++++--- src/tirx/ir/expr_functor.cc | 4 +- src/tirx/ir/function.cc | 18 +- src/tirx/ir/index_map.cc | 15 +- src/tirx/ir/layout/axis_registry.cc | 2 +- src/tirx/ir/layout/tile_slice.cc | 4 +- src/tirx/ir/layout/utils.cc | 2 +- src/tirx/ir/script/script_complete.cc | 5 +- src/tirx/ir/stmt.cc | 98 ++- src/tirx/ir/stmt_functor.cc | 8 +- src/tirx/op/op.cc | 777 ++++++++++-------- src/tirx/script/builder/ir.cc | 111 +-- src/tirx/script/builder/utils.h | 2 +- src/tirx/script/printer/block.cc | 4 +- src/tirx/script/printer/buffer.cc | 10 +- src/tirx/script/printer/expr.cc | 25 +- src/tirx/script/printer/for_loop.cc | 4 +- src/tirx/script/printer/ir.cc | 6 +- src/tirx/script/printer/stmt.cc | 6 +- src/tirx/transform/common_subexpr_elim.cc | 5 +- src/tirx/transform/dtype_conversion.cc | 33 +- src/tirx/transform/dtype_conversion.h | 57 +- src/tirx/transform/flatten_buffer.cc | 25 +- .../transform/force_narrow_index_to_i32.cc | 8 +- src/tirx/transform/ir_utils.cc | 9 +- src/tirx/transform/ir_utils.h | 43 +- src/tirx/transform/lower_intrin.cc | 38 +- src/tirx/transform/lower_tirx_cleanup.cc | 20 +- src/tirx/transform/lower_tirx_opaque.cc | 4 +- src/tirx/transform/lower_tvm_builtin.cc | 110 +-- src/tirx/transform/lower_warp_memory.cc | 24 +- src/tirx/transform/make_packed_api.cc | 40 +- src/tirx/transform/narrow_datatype.cc | 97 +-- src/tirx/transform/split_host_device.cc | 20 +- src/tirx/transform/storage_rewrite.cc | 158 ++-- src/tirx/transform/tile_primitive_dispatch.cc 
| 21 +- src/tirx/transform/tvm_ffi_binder.cc | 136 +-- src/tirx/transform/tvm_ffi_binder.h | 6 +- src/tirx/transform/unroll_loop.cc | 2 +- .../transform/unsupported_dtype_legalize.cc | 183 +++-- src/tirx/transform/vectorize_loop.cc | 244 +++--- src/topi/einsum.cc | 10 +- src/topi/elemwise.cc | 6 +- src/topi/nn.cc | 2 +- src/topi/transform.cc | 8 +- tests/cpp/arith_simplify_test.cc | 8 +- tests/cpp/expr_test.cc | 15 +- tests/cpp/ir_functor_test.cc | 24 +- tests/cpp/ndarray_test.cc | 8 +- tests/cpp/nested_msg_test.cc | 27 +- tests/cpp/pattern_match_test.cc | 34 +- tests/cpp/te_compute_test.cc | 10 +- tests/cpp/tir_analysis_side_effect.cc | 8 +- tests/cpp/tir_scalable_datatype.cc | 132 +-- tests/cpp/topi_ewise_test.cc | 2 +- .../codegen/test_target_codegen_llvm.py | 2 +- tests/python/contrib/test_sort.py | 16 +- .../python/relax/frontend_nn_extern_module.cc | 16 +- .../python/relax/test_analysis_well_formed.py | 2 +- tests/python/tirx-base/test_tir_buffer.py | 4 +- tests/python/tirx-base/test_tir_intrin.py | 24 +- tests/python/tirx-base/test_tir_specialize.py | 4 +- .../tvmscript/test_tvmscript_parser_tir.py | 4 +- .../tvmscript/test_tvmscript_roundtrip.py | 4 +- 420 files changed, 6481 insertions(+), 5737 deletions(-) create mode 100644 include/tvm/ir/base_expr.h delete mode 100644 include/tvm/runtime/data_type.h diff --git a/include/tvm/ir/base_expr.h b/include/tvm/ir/base_expr.h new file mode 100644 index 000000000000..fbde9ec26aca --- /dev/null +++ b/include/tvm/ir/base_expr.h @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/base_expr.h + * \brief Base expression and primitive type nodes. + */ +#ifndef TVM_IR_BASE_EXPR_H_ +#define TVM_IR_BASE_EXPR_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { + +/*! + * \brief Type is the base type of all types. + * + * TVM's type system contains following subclasses: + * + * - PrimType: type of primitive type values used in the low-level IR. + * - FuncType: type of a function. + * - TensorType: type of certain Tensor values in the expression. + * + * There are also advanced types to support generic(polymorphic types). + * \sa Type + */ +class TypeNode : public ffi::Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span do not participate in structural equal and hash. + refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 14; + TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); +}; + +/*! + * \brief Managed reference to TypeNode. + * \sa TypeNode + */ +class Type : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); +}; + +/*! + * \brief Primitive data types used in the low-level IR. 
+ * + * PrimType represents POD-values and handles that are + * not automatically managed by the runtime. + * + * \sa PrimType + */ +class PrimTypeNode final : public TypeNode { + public: + /*! + * \brief The raw DLPack dtype represented by this primitive type. + */ + DLDataType dtype; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); + } + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); +}; + +/* + * \brief Managed reference to PrimTypeNode. + * \sa PrimTypeNode + */ +class PrimType final : public Type { + public: + /*! + * \brief Construct from a raw DLPack dtype. + * \param dtype The corresponding DLPack dtype. + */ + TVM_DLL explicit PrimType(DLDataType dtype); + + /*! + * \brief Construct from DLPack dtype fields. + * \param code The DLPack dtype code. + * \param bits The scalar bit width. + * \param lanes The fixed lane count. + */ + TVM_DLL PrimType(DLDataTypeCode code, int bits, int lanes = 1); + + /*! \brief Construct a signed integer type with fixed lanes. */ + TVM_DLL static PrimType Int(int bits, int lanes = 1); + /*! \brief Construct an unsigned integer type with fixed lanes. */ + TVM_DLL static PrimType UInt(int bits, int lanes = 1); + /*! \brief Construct a floating-point type with fixed lanes. */ + TVM_DLL static PrimType Float(int bits, int lanes = 1); + /*! \brief Construct a bfloat type with fixed lanes. */ + TVM_DLL static PrimType BFloat(int bits, int lanes = 1); + /*! \brief Construct a boolean type with fixed lanes. */ + TVM_DLL static PrimType Bool(int lanes = 1); + /*! \brief Construct an opaque handle type. */ + TVM_DLL static PrimType Handle(int bits = 64, int lanes = 1); + /*! \brief Construct the void sentinel type, encoded as handle(0, 0). */ + TVM_DLL static PrimType Void(); + /*! + * \brief Construct a scalable vector type. + * \param code The DLPack dtype code. + * \param bits The scalar bit width. 
+ * \param lanes The positive vscale factor to encode in the DLPack lane field. + */ + TVM_DLL static PrimType ScalableVector(DLDataTypeCode code, int bits, int lanes); + + /*! \return The DLPack dtype code. */ + TVM_FFI_INLINE DLDataTypeCode code() const { + return static_cast(static_cast(get()->dtype.code)); + } + + /*! \return The scalar bit width. */ + TVM_FFI_INLINE int32_t bits() const { return get()->dtype.bits; } + + /*! + * \return The fixed lane count. + * \note Throws on scalable vector types, where the encoded lane field stores a vscale factor. + */ + TVM_FFI_INLINE int32_t lanes() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) { + TVM_FFI_THROW(InternalError) + << "Can't fetch the lanes of a scalable vector at a compile time."; + } + return encoded_lanes; + } + + /*! + * \brief Check the scalar element code and bit width. + * \note Lane count and scalable-vector encoding are intentionally ignored. + */ + TVM_FFI_INLINE bool MatchesElementType(DLDataTypeCode code, int bits) const { + DLDataType dtype = get()->dtype; + return dtype.code == static_cast(code) && dtype.bits == bits; + } + + /*! + * \brief Check whether the dtype code matches any of the provided DLPack codes. + * \note Bit width and lanes are intentionally ignored. + */ + template + TVM_FFI_INLINE bool MatchesCode(Codes... codes) const { + uint8_t dtype_code = get()->dtype.code; + return ((dtype_code == static_cast(codes)) || ...); + } + + /*! \brief Whether this type is a scalar, excluding fixed and scalable vectors. */ + TVM_FFI_INLINE bool IsScalar() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + return encoded_lanes == 1; + } + + /*! \brief Whether this type is the void sentinel `handle(0, 0)`. 
*/ + TVM_FFI_INLINE bool IsVoid() const { + DLDataType dtype = get()->dtype; + return dtype.code == static_cast(DLDataTypeCode::kDLOpaqueHandle) && dtype.bits == 0 && + static_cast(dtype.lanes) == 0; + } + + /*! \brief Whether this type is an opaque handle, excluding the void sentinel. */ + TVM_FFI_INLINE bool IsHandle() const { + return this->code() == DLDataTypeCode::kDLOpaqueHandle && !this->IsVoid(); + } + + /*! \brief Whether this type is a scalable vector. */ + TVM_FFI_INLINE bool IsScalableVector() const { + return static_cast(get()->dtype.lanes) < -1; + } + + /*! \brief Whether this type is a fixed-length vector. */ + TVM_FFI_INLINE bool IsFixedLengthVector() const { + return static_cast(get()->dtype.lanes) > 1; + } + + /*! \brief Return the same type with a different dtype code, preserving bits and lanes. */ + TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const { + DLDataType dtype = get()->dtype; + int16_t encoded_lanes = static_cast(dtype.lanes); + if (encoded_lanes < -1) { + return ScalableVector(code, dtype.bits, -encoded_lanes); + } + return PrimType(code, dtype.bits, encoded_lanes); + } + + /*! \brief Return the same type with a different scalar bit width, preserving code and lanes. */ + TVM_FFI_INLINE PrimType WithBits(int bits) const { + DLDataType dtype = get()->dtype; + int16_t encoded_lanes = static_cast(dtype.lanes); + if (encoded_lanes < -1) { + return ScalableVector(this->code(), bits, -encoded_lanes); + } + return PrimType(this->code(), bits, encoded_lanes); + } + + /*! \brief Return the same scalar element type with a fixed lane count. */ + TVM_FFI_INLINE PrimType WithLanes(int lanes) const { + return PrimType(this->code(), this->bits(), lanes); + } + + /*! \return The vscale factor encoded in a scalable vector type. 
*/ + TVM_FFI_INLINE int32_t VScaleFactor() const { + int16_t encoded_lanes = static_cast(get()->dtype.lanes); + if (encoded_lanes >= -1) { + TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor."; + } + return -encoded_lanes; + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode); +}; + +inline bool operator==(const PrimType& lhs, const PrimType& rhs) { + return lhs->dtype == rhs->dtype; +} + +inline bool operator!=(const PrimType& lhs, const PrimType& rhs) { return !(lhs == rhs); } + +/*! + * \brief Base type of all the expressions. + * \sa Expr + */ +class BaseExprNode : public ffi::Object { + public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + /*! + * \brief The deduced or annotated type of the expression. + * + * This field is intentionally nullable because type information may + * be populated by later analysis passes instead of expression + * constructors. + */ + mutable Type ty; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + // span and ty do not participate in structural equal and hash. + refl::ObjectDef() + .def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), + refl::AttachFieldFlag::SEqHashIgnore()) + .def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()), + refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + + static constexpr const uint32_t _type_child_slots = 64; + TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, ffi::Object); +}; + +/*! + * \brief Managed reference to BaseExprNode. 
+ * \sa BaseExprNode + */ +class BaseExpr : public ffi::ObjectRef { + public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ffi::ObjectRef, BaseExprNode); +}; + +namespace ffi { +template <> +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static PrimType ConvertFallbackValue(DLDataType dtype) { return PrimType(dtype); } +}; +} // namespace ffi + +} // namespace tvm + +#endif // TVM_IR_BASE_EXPR_H_ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b81e4c2feda7..70e1ffeb480c 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,12 +24,13 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ +#include #include #include #include +#include #include #include -#include #include #include @@ -54,82 +55,6 @@ class VirtualDevice; * There are also advanced types to support generic(polymorphic types). * \sa Type */ -class TypeNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - // span do not participate in structural equal and hash. - refl::ObjectDef().def_ro("span", &TypeNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 14; - TVM_FFI_DECLARE_OBJECT_INFO("ir.Type", TypeNode, ffi::Object); -}; - -/*! - * \brief Managed reference to TypeNode. - * \sa TypeNode - */ -class Type : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Type, ffi::ObjectRef, TypeNode); -}; - -/*! - * \brief Base type of all the expressions. - * \sa Expr - */ -class BaseExprNode : public ffi::Object { - public: - /*! - * \brief Span that points to the original source code. 
- * Reserved debug information. - */ - mutable Span span; - - /*! - * \brief The deduced or annotated type of the expression. - * - * This field is intentionally nullable because type information may - * be populated by later analysis passes instead of expression - * constructors. - */ - mutable Type ty; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - // span and ty do not participate in structural equal and hash. - refl::ObjectDef() - .def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span()), - refl::AttachFieldFlag::SEqHashIgnore()) - .def_ro("ty", &BaseExprNode::ty, refl::DefaultValue(Type()), - refl::AttachFieldFlag::SEqHashIgnore()); - } - - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - - static constexpr const uint32_t _type_child_slots = 64; - TVM_FFI_DECLARE_OBJECT_INFO("ir.BaseExpr", BaseExprNode, ffi::Object); -}; - -/*! - * \brief Managed reference to BaseExprNode. - * \sa BaseExprNode - */ -class BaseExpr : public ffi::ObjectRef { - public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BaseExpr, ffi::ObjectRef, BaseExprNode); -}; - /*! * \brief Base node of all primitive expressions. * @@ -144,25 +69,16 @@ class BaseExpr : public ffi::ObjectRef { */ class PrimExprNode : public BaseExprNode { public: - /*! - * \brief The runtime data type of the primitive expression. - * - * runtime::DataType(dtype) provides coarse grained type information - * during compile time and runtime. It is eagerly built in - * PrimExpr expression construction and can be used for - * quick type checking. - * - * dtype is sufficient to decide the Type of the PrimExpr - * when it corresponds to POD value types such as i32. - * - * When dtype is DataType::Handle(), the expression could corresponds to - * a more fine-grained Type, and we can get the type by running lazy type inference. - */ - DataType dtype; + /*! \return the primitive type of this expression node. 
*/ + PrimType ty() const { + TVM_FFI_DCHECK(this->BaseExprNode::ty.defined()); + TVM_FFI_DCHECK(this->BaseExprNode::ty->IsInstance()); + return ffi::GetRef(static_cast(this->BaseExprNode::ty.get())); + } static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); + refl::ObjectDef(); } static constexpr const uint32_t _type_child_slots = 40; @@ -186,8 +102,13 @@ class PrimExpr : public BaseExpr { */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! \return the data type of this expression. */ - DataType dtype() const { return static_cast(get())->dtype; } + /*! \return the primitive type of this expression. */ + PrimType ty() const { + const auto* node = static_cast(get()); + TVM_FFI_DCHECK(node->BaseExprNode::ty.defined()); + TVM_FFI_DCHECK(node->BaseExprNode::ty->IsInstance()); + return ffi::GetRef(static_cast(node->BaseExprNode::ty.get())); + } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimExpr, BaseExpr, PrimExprNode); @@ -554,11 +475,11 @@ class IntImm : public PrimExpr { public: /*! * \brief Constructor. - * \param dtype The data type of the value. + * \param value_ty The primitive type of the value. * \param value The internal value. * \param span The location of this object in the source code. */ - TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); + TVM_DLL IntImm(PrimType value_ty, int64_t value, Span span = Span()); /*! * \brief Construct a scalar boolean constant. @@ -566,7 +487,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. */ static IntImm Bool(bool value, Span span = Span()) { - return IntImm(DataType::Bool(), value, span); + return IntImm(PrimType::Bool(), value, span); } /*! @@ -575,7 +496,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. 
*/ static IntImm Int32(int64_t value, Span span = Span()) { - return IntImm(DataType::Int(32), value, span); + return IntImm(PrimType::Int(32), value, span); } /*! @@ -584,7 +505,7 @@ class IntImm : public PrimExpr { * \param span The location of this object in the source code. */ static IntImm Int64(int64_t value, Span span = Span()) { - return IntImm(DataType::Int(64), value, span); + return IntImm(PrimType::Int(64), value, span); } TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode); @@ -616,11 +537,11 @@ class FloatImm : public PrimExpr { public: /*! * \brief Constructor. - * \param dtype The data type of the value. + * \param value_ty The primitive type of the value. * \param value The internal value. * \param span The location in the source code. */ - TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); + TVM_DLL FloatImm(PrimType value_ty, double value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FloatImm, PrimExpr, FloatImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); @@ -688,11 +609,11 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static IntImm ConvertFallbackValue(int64_t value) { - auto dtype = + auto value_ty = (value > std::numeric_limits::max() || value < std::numeric_limits::min()) - ? DataType::Int(64) - : DataType::Int(32); - return IntImm(dtype, value); + ? 
PrimType::Int(64) + : PrimType::Int(32); + return IntImm(value_ty, value); } }; @@ -702,7 +623,7 @@ inline constexpr bool use_default_type_traits_v = false; template <> struct TypeTraits : public ObjectRefWithFallbackTraitsBase { TVM_FFI_INLINE static FloatImm ConvertFallbackValue(double value) { - return FloatImm(runtime::DataType::Float(32), value); + return FloatImm(PrimType::Float(32), value); } }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 9c56d0376405..f63b5d261500 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -26,21 +26,19 @@ * * This file contains types that are common across IR variants. * - * ## Relation between Type and runtime::DataType + * ## Relation between Type and DLPack dtype * - * Besides Type, we also store a dtype field in the low-level PrimExpr. - * runtime::DataType(dtype) provides coarse grained type information - * during compile time and runtime. It is eagerly built in - * low-level expression construction and can be used for - * quick type checking in the low-level IR. - * For example, when an Expr's dtype is int32, - * we know for sure that its type is also int32. + * PrimExpr stores a PrimType in its `ty` field, backed by a DLPack + * `DLDataType`. This provides coarse grained scalar/vector element type + * information during compile time and runtime. It is eagerly built in + * low-level expression construction and can be used for quick type checking + * in the low-level IR. For example, when an Expr's dtype is int32, we know + * for sure that its PrimType is also int32. * * On the other hand, Type provides more fine grained information. - * For example, a low level expression can have DataType::Handle() as - * its dtype and MemRef[float32] as its type. - * Types are usually lazily constructed via type checking, - * so they may not readily be available during IR construction. 
+ * For example, a low level expression can have a handle dtype while a + * node-specific type annotation records a + * PointerType to a float32 element. * * The unified Type serves as a common bridge across IR dialects. * For example, we require all the functions to have a type signature, @@ -49,55 +47,16 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include #include +#include #include -#include +#include #include -#include #include namespace tvm { -/*! - * \brief Primitive data types used in the low-level IR. - * - * PrimType represents POD-values and handles that are - * not automatically managed by the runtime. - * - * \sa PrimType - */ -class PrimTypeNode : public TypeNode { - public: - /*! - * \brief The corresponding dtype field. - */ - runtime::DataType dtype; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("dtype", &PrimTypeNode::dtype); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.PrimType", PrimTypeNode, TypeNode); -}; - -/* - * \brief Managed reference to PrimTypeNode. - * \sa PrimTypeNode - */ -class PrimType : public Type { - public: - /*! - * \brief Constructor - * \param dtype The corresponding dtype. - * \param span The span - */ - TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span()); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(PrimType, Type, PrimTypeNode); -}; - /*! * \brief Low-level raw pointer type. * diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h index 14a3402f2503..76ef219a862c 100644 --- a/include/tvm/relax/attrs/create.h +++ b/include/tvm/relax/attrs/create.h @@ -31,7 +31,7 @@ namespace relax { /*! 
\brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ struct InitAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h index f67223edb546..aeac65e64484 100644 --- a/include/tvm/relax/attrs/datatype.h +++ b/include/tvm/relax/attrs/datatype.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes used in astype operator */ struct AstypeAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -42,7 +42,7 @@ struct AstypeAttrs : public AttrsNode { /*! \brief Attributes used in wrap_param operator */ struct WrapParamAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h index c9a720374036..8f512f28e55f 100644 --- a/include/tvm/relax/attrs/image.h +++ b/include/tvm/relax/attrs/image.h @@ -39,7 +39,7 @@ struct Resize2DAttrs : public AttrsNode { double cubic_alpha; int cubic_exclude; double extrapolation_value; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -88,7 +88,7 @@ struct Resize3DAttrs : public AttrsNode { double cubic_alpha; int cubic_exclude; double extrapolation_value; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h index 817885edb871..19a5982bfe12 100644 --- a/include/tvm/relax/attrs/linear_algebra.h +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -31,7 +31,7 @@ namespace relax { /*! 
\brief Attributes for matmul operator */ struct MatmulAttrs : public AttrsNode { - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 52d9c40d742d..aa3c0f4736f0 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -38,7 +38,7 @@ struct Conv1DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -82,7 +82,7 @@ struct Conv2DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -128,7 +128,7 @@ struct Conv3DAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -177,7 +177,7 @@ struct Conv1DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -226,7 +226,7 @@ struct Conv2DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -277,7 +277,7 @@ struct Conv3DTransposeAttrs : public AttrsNode { ffi::String data_layout; ffi::String kernel_layout; ffi::String out_layout; - DataType out_dtype; + DLDataType out_dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/include/tvm/relax/attrs/qdq.h 
b/include/tvm/relax/attrs/qdq.h index 83ec2223c3c7..be95b9e7b8ed 100644 --- a/include/tvm/relax/attrs/qdq.h +++ b/include/tvm/relax/attrs/qdq.h @@ -31,7 +31,7 @@ namespace relax { /*! \brief Attributes for relax.quantize/relax.dequantize operator */ struct QuantizeAttrs : public AttrsNode { - DataType out_dtype; + DLDataType out_dtype; int axis; static void RegisterReflection() { diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h index 11bbfb6eba31..07b7de25e553 100644 --- a/include/tvm/relax/attrs/sampling.h +++ b/include/tvm/relax/attrs/sampling.h @@ -31,13 +31,13 @@ namespace relax { /*! \brief Attributes used in multinomial_from_uniform operator */ struct MultinomialFromUniformAttrs : public AttrsNode { - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "dtype", &MultinomialFromUniformAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(DataType::Int(64))); + refl::DefaultValue((DLDataType{kDLInt, 64, 1}))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultinomialFromUniformAttrs", MultinomialFromUniformAttrs, AttrsNode); diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index e8bf65d55a43..ef21bf9a637e 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -54,7 +54,7 @@ struct SortAttrs : public AttrsNode { struct ArgsortAttrs : public AttrsNode { int axis; bool descending; - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -68,7 +68,7 @@ struct ArgsortAttrs : public AttrsNode { "If it is not specified, it defaults to the ascending order.", refl::DefaultValue(false)) .def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.", - refl::DefaultValue(DataType::Void())); + refl::DefaultValue((DLDataType{kDLOpaqueHandle, 0, 0}))); } 
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs, AttrsNode); }; // struct ArgsortAttrs @@ -79,7 +79,7 @@ struct TopKAttrs : public AttrsNode { int axis; bool largest; ffi::String ret_type; - DataType dtype; + DLDataType dtype; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -98,7 +98,7 @@ struct TopKAttrs : public AttrsNode { "By default, return the largest k elements.", refl::DefaultValue(true)) .def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.", - refl::DefaultValue(DataType::Void())); + refl::DefaultValue((DLDataType{kDLOpaqueHandle, 0, 0}))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs, AttrsNode); }; // struct TopKAttrs diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h index 66996c802cc3..a815e0e07e51 100644 --- a/include/tvm/relax/attrs/statistical.h +++ b/include/tvm/relax/attrs/statistical.h @@ -50,7 +50,7 @@ struct StatisticalAttrs : public AttrsNode { /*! \brief Attributes used in scan operators like cumsum, cumprod */ struct ScanopAttrs : public AttrsNode { ffi::Optional axis; - DataType dtype; + DLDataType dtype; bool exclusive = false; static void RegisterReflection() { diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 27894da3addd..0511395f8a67 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -116,8 +116,8 @@ class DFPattern : public ffi::ObjectRef { TVM_DLL AttrPattern HasAttr(const ffi::Map& attrs) const; /*! \brief Syntatic Sugar for creating a TypePattern */ TVM_DLL TypePattern HasType(const Type& ty) const; - /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ - TVM_DLL DataTypePattern HasDtype(const DataType& dtype) const; + /*! \brief Syntatic Sugar for creating a DataTypePattern with a dtype */ + TVM_DLL DataTypePattern HasDtype(DLDataType dtype) const; /*! 
\brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const; /*! \brief Syntatic Sugar for creating a ShapePattern */ @@ -860,7 +860,7 @@ class SameShapeConstraint : public DFConstraint { class DataTypePatternNode : public DFPatternNode { public: DFPattern pattern; /*!< The root pattern to match */ - DataType dtype; /*!< The data type to match */ + DLDataType dtype; /*!< The data type to match */ static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -878,7 +878,7 @@ class DataTypePatternNode : public DFPatternNode { */ class DataTypePattern : public DFPattern { public: - TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DLL DataTypePattern(DFPattern pattern, DLDataType dtype); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypePattern, DFPattern, DataTypePatternNode); }; diff --git a/include/tvm/relax/distributed/global_info.h b/include/tvm/relax/distributed/global_info.h index 62ff904fc1a4..0347ec3b85a8 100644 --- a/include/tvm/relax/distributed/global_info.h +++ b/include/tvm/relax/distributed/global_info.h @@ -25,6 +25,7 @@ #ifndef TVM_RELAX_DISTRIBUTED_GLOBAL_INFO_H_ #define TVM_RELAX_DISTRIBUTED_GLOBAL_INFO_H_ +#include #include #include namespace tvm { diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 937091255b6f..0b75bf27a7d2 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -471,7 +471,7 @@ class StringImm : public LeafExpr { class DataTypeImmNode : public LeafExprNode { public: /*! \brief The data value. */ - DataType value; + DLDataType value; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -491,7 +491,7 @@ class DataTypeImm : public LeafExpr { * \param value The value input. * \param span The source span of the expression. 
*/ - TVM_DLL explicit DataTypeImm(DataType value, Span span = Span()); + TVM_DLL explicit DataTypeImm(DLDataType value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DataTypeImm, LeafExpr, DataTypeImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DataTypeImmNode); diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index d0d0d1bb5441..5c757ba15161 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -663,9 +663,8 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ -TVM_DLL Pass -ToMixedPrecision(const DataType& out_dtype, - ffi::Optional> fp16_input_names = std::nullopt); +TVM_DLL Pass ToMixedPrecision( + DLDataType out_dtype, ffi::Optional> fp16_input_names = std::nullopt); /*! * \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 9c27b627a7d6..a77a3cc66c38 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -124,7 +124,7 @@ class ShapeTypeNode : public TypeNode { * \brief The number of dimension of the shape, can be unknown. * \sa kUnknownNDim */ - int ndim; + int ndim{kUnknownNDim}; /*! \return Whether the type contains unknown ndim. */ bool IsUnknownNdim() const { return ndim == kUnknownNDim; } @@ -174,19 +174,19 @@ class TensorTypeNode : public TypeNode { * is expected to be executed. */ ffi::Optional vdevice; - /*! \brief The content data type, use void to denote the dtype is unknown. */ - DataType dtype; + /*! \brief The content dtype, use void to denote the dtype is unknown. */ + tvm::PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! * \brief The number of dimension of the tensor, can be unknown. * \sa kUnknownNDim */ - int ndim; + int ndim{kUnknownNDim}; /*! \return Whether the type contains unknown ndim. */ bool IsUnknownNdim() const { return ndim == kUnknownNDim; } /*! 
\return Whether the type contains unknown dtype. */ - bool IsUnknownDtype() const { return dtype.is_void(); } + bool IsUnknownDtype() const { return dtype->dtype == DLDataType{kDLOpaqueHandle, 0, 0}; } /*! \return Shape if it is known. */ ffi::Optional> GetShape() const { @@ -230,7 +230,7 @@ class TensorType : public Type { * * \note shape must already be normalized. */ - TVM_DLL TensorType(Expr shape, DataType dtype, ffi::Optional vdevice = std::nullopt, + TVM_DLL TensorType(Expr shape, tvm::PrimType dtype, ffi::Optional vdevice = std::nullopt, Span span = Span()); /*! @@ -240,7 +240,7 @@ class TensorType : public Type { * \param vdevice The virtual device. * \param span The span of the AST. */ - TVM_DLL TensorType(DataType dtype, int ndim, ffi::Optional vdevice = std::nullopt, + TVM_DLL TensorType(tvm::PrimType dtype, int ndim, ffi::Optional vdevice = std::nullopt, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorType, Type, TensorTypeNode); diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h deleted file mode 100644 index 9f230cac824e..000000000000 --- a/include/tvm/runtime/data_type.h +++ /dev/null @@ -1,522 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file tvm/runtime/data_type.h - * \brief Primitive runtime data type. - */ -// Acknowledgement: DataType structure design originates from Halide. -#ifndef TVM_RUNTIME_DATA_TYPE_H_ -#define TVM_RUNTIME_DATA_TYPE_H_ - -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Runtime primitive data type. - * - * This class is a thin wrapper of DLDataType. - * We also make use of DataType in compiler to store quick hint - */ -class DataType { - public: - /*! - * \brief Type code for the DataType. - * - * DLPack consistency: - * 1) kInt is consistent with kDLInt - * 2) kUInt is consistent with kDLUInt - * 3) kFloat is consistent with kDLFloat - */ - enum TypeCode { - kInt = kDLInt, - kUInt = kDLUInt, - kFloat = kDLFloat, - kHandle = kDLOpaqueHandle, - kBFloat = kDLBfloat, - kBool = kDLBool, - kFloat8_e3m4 = kDLFloat8_e3m4, - kFloat8_e4m3 = kDLFloat8_e4m3, - kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz, - kFloat8_e4m3fn = kDLFloat8_e4m3fn, - kFloat8_e4m3fnuz = kDLFloat8_e4m3fnuz, - kFloat8_e5m2 = kDLFloat8_e5m2, - kFloat8_e5m2fnuz = kDLFloat8_e5m2fnuz, - kFloat8_e8m0fnu = kDLFloat8_e8m0fnu, - kFloat6_e2m3fn = kDLFloat6_e2m3fn, - kFloat6_e3m2fn = kDLFloat6_e3m2fn, - kFloat4_e2m1fn = kDLFloat4_e2m1fn, - kCustomBegin = 129 - }; - /*! \brief default constructor */ - DataType() { data_ = DataType::Void(); } - /*! - * \brief Constructor - * \param dtype The DLDataType - */ - explicit DataType(DLDataType dtype) : data_(dtype) {} - /*! - * \brief Constructor - * \param code The type code. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. 
- */ - DataType(int code, int bits, int lanes, bool is_scalable = false) { - data_.code = static_cast(code); - data_.bits = static_cast(bits); - if (is_scalable) { - TVM_FFI_ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes; - } - data_.lanes = is_scalable ? static_cast(-lanes) : static_cast(lanes); - if (code == kBFloat) { - TVM_FFI_ICHECK_EQ(bits, 16); - } - if (code == kFloat8_e3m4 || code == kFloat8_e4m3 || code == kFloat8_e4m3b11fnuz || - code == kFloat8_e4m3fn || code == kFloat8_e4m3fnuz || code == kFloat8_e5m2 || - code == kFloat8_e5m2fnuz || code == kFloat8_e8m0fnu) { - TVM_FFI_ICHECK_EQ(bits, 8); - } - if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) { - TVM_FFI_ICHECK_EQ(bits, 6); - } - if (code == kFloat4_e2m1fn) { - TVM_FFI_ICHECK_EQ(bits, 4); - } - } - /*! \return The type code. */ - int code() const { return static_cast(data_.code); } - /*! \return number of bits in the data. */ - int bits() const { return static_cast(data_.bits); } - /*! \return number of bytes to store each scalar. */ - int bytes() const { return (bits() + 7) / 8; } - /*! \return number of lanes in the data. */ - int lanes() const { - int lanes_as_int = static_cast(data_.lanes); - if (lanes_as_int < 0) { - TVM_FFI_THROW(InternalError) - << "Can't fetch the lanes of a scalable vector at a compile time."; - } - return lanes_as_int; - } - /*! \return the integer multiplier of vscale in a scalable vector. */ - int vscale_factor() const { - int lanes_as_int = static_cast(data_.lanes); - if (lanes_as_int >= -1) { - TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor."; - } - return -lanes_as_int; - } - /*! \return get vscale factor or lanes depending on scalability of the vector. */ - int get_lanes_or_vscale_factor() const { - return is_scalable_vector() ? vscale_factor() : lanes(); - } - /*! \return whether type is a scalar type. */ - bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; } - /*! 
\return whether type is a bool type. */ - bool is_bool() const { return code() == DataType::kBool; } - /*! \return whether type can be used in a predicate expression. */ - bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); } - /*! \return whether type is a float type. */ - bool is_float() const { return code() == DataType::kFloat; } - /*! \return whether type is a bfloat type. */ - bool is_bfloat() const { return code() == DataType::kBFloat; } - /*! \return whether type is any 8-bit custom Float8 variant. */ - bool is_float8() const { - return bits() == 8 && - (code() == DataType::kFloat8_e3m4 || code() == DataType::kFloat8_e4m3 || - code() == DataType::kFloat8_e4m3b11fnuz || code() == DataType::kFloat8_e4m3fn || - code() == DataType::kFloat8_e4m3fnuz || code() == DataType::kFloat8_e5m2 || - code() == DataType::kFloat8_e5m2fnuz || code() == DataType::kFloat8_e8m0fnu); - } - /*! \return whether type is any 6-bit custom Float6 variant. */ - bool is_float6() const { - return bits() == 6 && - (code() == DataType::kFloat6_e2m3fn || code() == DataType::kFloat6_e3m2fn); - } - /*! \return whether type is the 4-bit custom Float4_e2m1fn variant. */ - bool is_float4() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } - /*! \return whether type is Float8E3M4. */ - bool is_float8_e3m4() const { return bits() == 8 && code() == DataType::kFloat8_e3m4; } - /*! \return whether type is Float8E4M3. */ - bool is_float8_e4m3() const { return bits() == 8 && code() == DataType::kFloat8_e4m3; } - /*! \return whether type is Float8E4M3B11FNUZ. */ - bool is_float8_e4m3b11fnuz() const { - return bits() == 8 && code() == DataType::kFloat8_e4m3b11fnuz; - } - /*! \return whether type is Float8E4M3FN. */ - bool is_float8_e4m3fn() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fn; } - /*! \return whether type is Float8E4M3FNUZ. */ - bool is_float8_e4m3fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e4m3fnuz; } - /*! 
\return whether type is Float8E5M2. */ - bool is_float8_e5m2() const { return bits() == 8 && code() == DataType::kFloat8_e5m2; } - /*! \return whether type is Float8E5M2FNUZ. */ - bool is_float8_e5m2fnuz() const { return bits() == 8 && code() == DataType::kFloat8_e5m2fnuz; } - /*! \return whether type is Float8E8M0FNU. */ - bool is_float8_e8m0fnu() const { return bits() == 8 && code() == DataType::kFloat8_e8m0fnu; } - /*! \return whether type is Float6E2M3FN. */ - bool is_float6_e2m3fn() const { return bits() == 6 && code() == DataType::kFloat6_e2m3fn; } - /*! \return whether type is Float6E3M2FN. */ - bool is_float6_e3m2fn() const { return bits() == 6 && code() == DataType::kFloat6_e3m2fn; } - /*! \return whether type is Float4E2M1FN. */ - bool is_float4_e2m1fn() const { return bits() == 4 && code() == DataType::kFloat4_e2m1fn; } - /*! \return whether type is a float16 type. */ - bool is_float16() const { return is_float() && bits() == 16; } - /*! \return whether type is a bfloat16 type. */ - bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } - /*! \return whether type is an int type. */ - bool is_int() const { return code() == DataType::kInt; } - /*! \return whether type is an uint type. */ - bool is_uint() const { return code() == DataType::kUInt; } - /*! \return whether type is a handle type. */ - bool is_handle() const { return code() == DataType::kHandle && !is_void(); } - /*! \return whether type is a vector type. */ - bool is_scalable_or_fixed_length_vector() const { - int encoded_lanes = static_cast(data_.lanes); - return (encoded_lanes < -1) || (1 < encoded_lanes); - } - /*! \return Whether the type is a fixed length vector. */ - bool is_fixed_length_vector() const { return static_cast(data_.lanes) > 1; } - /*! \return Whether the type is a scalable vector. */ - bool is_scalable_vector() const { return static_cast(data_.lanes) < -1; } - /*! \return whether type is a vector type. 
*/ - bool is_vector() const { return lanes() > 1; } - /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && is_bool(); } - /*! \return whether type is a Void type. */ - bool is_void() const { - return code() == DataType::kHandle && bits() == 0 && static_cast(data_.lanes) == 0; - } - /*! - * \brief Create a new data type by change lanes to a specified value. - * \param lanes The target number of lanes. - * \return the result type. - */ - DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } - /*! - * \brief Create a new scalable vector data type by changing the vscale multiplier to a specified - * value. We'll use the data_.lanes field for this value. \param vscale_factor The vscale - * multiplier. \return A copy of the old DataType with the number of scalable lanes. - */ - DataType with_scalable_vscale_factor(int vscale_factor) const { - return DataType(data_.code, data_.bits, -vscale_factor); - } - /*! - * \brief Create a new data type by change bits to a specified value. - * \param bits The target number of bits. - * \return the result type. - */ - DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } - /*! - * \brief Get the scalar version of the type. - * \return the result type. - */ - DataType element_of() const { return with_lanes(1); } - /*! - * \brief Assignment operator. - */ - DataType& operator=(const DataType& rhs) { - if (this == &rhs) { - return *this; - } - data_ = rhs.data_; - return *this; - } - /*! - * \brief Equal comparator. - * \param other The data type to compare against. - * \return The comparison result. - */ - bool operator==(const DataType& other) const { - return data_.code == other.data_.code && data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; - } - /*! - * \brief NotEqual comparator. - * \param other The data type to compare against. - * \return The comparison result. 
- */ - bool operator!=(const DataType& other) const { return !operator==(other); } - /*! - * \brief Converter to DLDataType - * \return the result. - */ - operator DLDataType() const { return data_; } - - /*! - * \brief Construct an int type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \return The constructed data type. - */ - static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } - /*! - * \brief Construct an uint type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. - * \return The constructed data type. - */ - static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) { - return DataType(kDLUInt, bits, lanes, is_scalable); - } - /*! - * \brief Construct an float type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } - /*! - * \brief Construct an bfloat type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); } - /*! - * \brief Construct float8 e3m4 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E3M4(int lanes = 1) { return DataType(kFloat8_e3m4, 8, lanes); } - - /*! - * \brief Construct float8 e4m3 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3(int lanes = 1) { return DataType(kFloat8_e4m3, 8, lanes); } - - /*! - * \brief Construct float8 e4m3b11fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. 
- */ - static DataType Float8E4M3B11FNUZ(int lanes = 1) { - return DataType(kFloat8_e4m3b11fnuz, 8, lanes); - } - - /*! - * \brief Construct float8 e4m3fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3FN(int lanes = 1) { return DataType(kFloat8_e4m3fn, 8, lanes); } - - /*! - * \brief Construct float8 e4m3fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E4M3FNUZ(int lanes = 1) { return DataType(kFloat8_e4m3fnuz, 8, lanes); } - - /*! - * \brief Construct float8 e5m2 datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E5M2(int lanes = 1) { return DataType(kFloat8_e5m2, 8, lanes); } - - /*! - * \brief Construct float8 e5m2fnuz datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E5M2FNUZ(int lanes = 1) { return DataType(kFloat8_e5m2fnuz, 8, lanes); } - - /*! - * \brief Construct float8 e8m0fnu datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float8E8M0FNU(int lanes = 1) { return DataType(kFloat8_e8m0fnu, 8, lanes); } - - /*! - * \brief Construct float6 e2m3fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float6E2M3FN(int lanes = 1) { return DataType(kFloat6_e2m3fn, 6, lanes); } - - /*! - * \brief Construct float6 e3m2fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float6E3M2FN(int lanes = 1) { return DataType(kFloat6_e3m2fn, 6, lanes); } - - /*! - * \brief Construct float4 e2m1fn datatype. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Float4E2M1FN(int lanes = 1) { return DataType(kFloat4_e2m1fn, 4, lanes); } - /*! - * \brief Construct a bool type. 
- * \param lanes The number of lanes. - * \param is_scalable Whether the data type is scalable. - * \return The constructed data type. - */ - static DataType Bool(int lanes = 1, bool is_scalable = false) { - return DataType(kDLBool, 8, lanes, is_scalable); - } - /*! - * \brief Construct a handle type. - * \param bits The number of bits in the type. - * \param lanes The number of lanes - * \return The constructed data type. - */ - static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } - /*! - * \brief Construct a Void type. - * \return The constructed data type. - */ - static DataType Void() { return DataType(kHandle, 0, 0); } - /*! - * \brief Get the corresponding type of TVMShapeIndex. - * \return The type of TVM shape index. - */ - static DataType ShapeIndex() { - if (std::is_signed::value) { - return DataType::Int(sizeof(ffi::Shape::index_type) * 8); - } else { - return DataType::UInt(sizeof(ffi::Shape::index_type) * 8); - } - } - - private: - DLDataType data_; -}; - -/*! - * \brief Get the number of bytes needed in a vector. - * \param dtype The data type. - * \return Number of bytes needed. - */ -inline int GetVectorBytes(DataType dtype) { - int data_bits = dtype.bits() * dtype.lanes(); - // allow bool to exist - if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || - dtype == DataType::Int(1) || dtype == DataType::Float4E2M1FN() || - dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) { - return 1; - } - TVM_FFI_ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; - return data_bits / 8; -} - -/*! - * \brief Check whether type matches the given spec. - * \param t The type - * \param code The type code. - * \param bits The number of bits to be matched. - * \param lanes The number of lanes in the type. 
- */ -inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) { - return t.code == code && t.bits == bits && t.lanes == lanes; -} -/*! - * \brief Check whether two types are equal . - * \param lhs The left operand. - * \param rhs The right operand. - */ -inline bool TypeEqual(DLDataType lhs, DLDataType rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} - -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) - return os << dtype.operator DLDataType(); -} -} // namespace runtime - -using DataType = runtime::DataType; - -namespace ffi { - -// runtime::DataType -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; - - TVM_FFI_INLINE static void CopyToAnyView(const runtime::DataType& src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->zero_padding = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->v_dtype = src; - } - - TVM_FFI_INLINE static void MoveToAny(runtime::DataType src, TVMFFIAny* result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->zero_padding = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->v_dtype = src; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny* src) { - auto opt_dtype = TypeTraits::TryCastFromAnyView(src); - if (opt_dtype) { - return runtime::DataType(opt_dtype.value()); - } - return std::nullopt; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return TypeTraits::CheckAnyStrict(src); - } - - TVM_FFI_INLINE static runtime::DataType CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return runtime::DataType(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } - - TVM_FFI_INLINE static std::string TypeStr() { return 
ffi::StaticTypeKey::kTVMFFIDataType; } - - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})"; - } -}; - -} // namespace ffi -} // namespace tvm - -namespace std { -template <> -struct hash { - inline int cantor_pairing_function(int a, int b) const { return (a + b) * (a + b + 1) / 2 + b; } - std::size_t operator()(tvm::DataType const& dtype) const { - int a = dtype.code(); - int b = dtype.bits(); - int c = dtype.lanes(); - int d = cantor_pairing_function(a, b); - return cantor_pairing_function(c, d); - } -}; -} // namespace std - -#endif // TVM_RUNTIME_DATA_TYPE_H_ diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index a9487c866acc..9d66a09507c5 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -19,8 +19,8 @@ #ifndef TVM_RUNTIME_DISCO_BUILTIN_H_ #define TVM_RUNTIME_DISCO_BUILTIN_H_ +#include #include -#include #include #include @@ -70,7 +70,7 @@ TVM_RUNTIME_DLL ffi::Module LoadVMModule(std::string path, ffi::Optional * \param device The device the Tensor is created on. If None, use the thread local default device * \return The Tensor created */ -TVM_RUNTIME_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, +TVM_RUNTIME_DLL Tensor DiscoEmptyTensor(ffi::Shape shape, DLDataType dtype, ffi::Optional device); /*! 
* \brief Perform an allreduce operation using the underlying communication library diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index d3497c8ff78f..cb93c4abd741 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -26,10 +26,10 @@ #include #include +#include #include #include #include -#include #include #include #include @@ -59,7 +59,7 @@ class Tensor : public tvm::ffi::Tensor { Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*) ffi::ShapeView Shape() const { return this->shape(); } - runtime::DataType DataType() const { return runtime::DataType(this->dtype()); } + DLDataType DataType() const { return this->dtype(); } // DLPack handling static Tensor FromDLPack(DLManagedTensor* tensor) { diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 0f1927e0cbcb..ea246da5d354 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_VM_BYTECODE_H_ #define TVM_RUNTIME_VM_BYTECODE_H_ +#include #include -#include #include #include diff --git a/include/tvm/runtime/vm/tensor_cache_support.h b/include/tvm/runtime/vm/tensor_cache_support.h index ea997f0755bd..b112043c376f 100644 --- a/include/tvm/runtime/vm/tensor_cache_support.h +++ b/include/tvm/runtime/vm/tensor_cache_support.h @@ -54,7 +54,7 @@ struct TensorCacheMetadata { /*! \brief Shape of the parameter */ ffi::Shape shape; /*! \brief Data type of the parameter */ - DataType dtype; + DLDataType dtype; /*! \brief Format of the parameter */ std::string format; /*! \brief Number of bytes */ diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h index 48836c5a53d5..ee6d51832dba 100644 --- a/include/tvm/s_tir/data_layout.h +++ b/include/tvm/s_tir/data_layout.h @@ -140,10 +140,10 @@ class SLayout : public ffi::ObjectRef { * the corresponding lower case with factor size * indicates the split dimension. 
* return undefined layout if "__undef__" is passed. - * \param dtype The dtype of generated axes vars in the returned layout. + * \param index_ty The type of generated axes vars in the returned layout. * It is required to be integer type. */ - TVM_DLL SLayout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*) + TVM_DLL SLayout(const std::string& name, PrimType index_ty = PrimType::Int(32)); // NOLINT(*) /*! * \brief access the internal node container diff --git a/include/tvm/s_tir/meta_schedule/arg_info.h b/include/tvm/s_tir/meta_schedule/arg_info.h index 463e73b0e246..a346a73dd441 100644 --- a/include/tvm/s_tir/meta_schedule/arg_info.h +++ b/include/tvm/s_tir/meta_schedule/arg_info.h @@ -20,9 +20,9 @@ #define TVM_S_TIR_META_SCHEDULE_ARG_INFO_H_ #include +#include #include #include -#include #include namespace tvm { @@ -77,7 +77,7 @@ class ArgInfo : public ffi::ObjectRef { class TensorInfoNode : public ArgInfoNode { public: /*! \brief The data type of the tensor. */ - runtime::DataType dtype; + DLDataType dtype; /*! \brief The shape of the tensor. */ ffi::Shape shape; @@ -104,7 +104,7 @@ class TensorInfo : public ArgInfo { * \param dtype The data type of the tensor argument. * \param shape The shape tuple of the tensor argument. */ - TVM_DLL explicit TensorInfo(runtime::DataType dtype, ffi::Shape shape); + TVM_DLL explicit TensorInfo(DLDataType dtype, ffi::Shape shape); /*! * \brief Parse the argument information from a JSON object. * \param json_obj The json object to parse. diff --git a/include/tvm/script/printer/config.h b/include/tvm/script/printer/config.h index beea4042470c..e0ed32d38094 100644 --- a/include/tvm/script/printer/config.h +++ b/include/tvm/script/printer/config.h @@ -30,10 +30,11 @@ #include #include #include +#include #include #include #include -#include +#include #include @@ -53,15 +54,15 @@ class PrinterConfigNode : public ffi::Object { */ ffi::String module_alias = "cls"; /*! 
\brief Default buffer dtype */ - DataType buffer_dtype = DataType::Float(32); + DLDataType buffer_dtype = DLDataType{kDLFloat, 32, 1}; /*! \brief Default data type of integer literals */ - DataType int_dtype = DataType::Int(32); + DLDataType int_dtype = DLDataType{kDLInt, 32, 1}; /*! * \brief Default data type of float literals. Right now we always print out the explicit type * of floating point values, so setting it to Void means we do not print without the * T.float32/T.float64 wrapper. */ - DataType float_dtype = DataType::Void(); + DLDataType float_dtype = DLDataType{kDLOpaqueHandle, 0, 0}; /*! \brief Whether or not to verbose print expressions. */ bool verbose_expr = false; /*! \brief Number of spaces used for indentation*/ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 2389c1b50d15..bc90e5365734 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -19,10 +19,11 @@ #ifndef TVM_SCRIPT_PRINTER_DOC_H_ #define TVM_SCRIPT_PRINTER_DOC_H_ +#include #include #include #include -#include +#include #include #include @@ -293,7 +294,7 @@ class LiteralDoc : public ExprDoc { * \param p The object path */ static LiteralDoc Float(double v, const ffi::Optional& p) { - return LiteralDoc(FloatImm(DataType::Float(64), v), p); + return LiteralDoc(FloatImm(PrimType::Float(64), v), p); } /*! * \brief Create a LiteralDoc to represent string. @@ -308,8 +309,9 @@ class LiteralDoc : public ExprDoc { * \param v The string value. * \param p The object path */ - static LiteralDoc DataType(const runtime::DataType& v, const ffi::Optional& p) { - std::string dtype = v.is_void() ? "void" : ffi::DLDataTypeToString(v); + static LiteralDoc DataType(DLDataType v, const ffi::Optional& p) { + std::string dtype = + v == DLDataType{kDLOpaqueHandle, 0, 0} ? "void" : ffi::DLDataTypeToString(v); return LiteralDoc::Str(dtype, p); } /*! 
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 98249c6f30bd..e9c82265ff27 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -333,7 +333,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con return LiteralDoc::Str(string_value, path).as_or_throw(); } case ffi::TypeIndex::kTVMFFIDataType: - return LiteralDoc::DataType(value.as().value(), path).as_or_throw(); + return LiteralDoc::DataType(value.as().value(), path).as_or_throw(); case ffi::TypeIndex::kTVMFFIDevice: return LiteralDoc::Device(value.as().value(), path).as_or_throw(); default: { diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index c9d35a77fe99..ba5267a8ce85 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -34,6 +34,7 @@ #include #include +#include #include namespace tvm { @@ -67,11 +68,11 @@ class TVM_DLL OperationNode : public ffi::Object { /*! \return number of outputs */ virtual int num_outputs() const = 0; /*! - * \brief Get data type. i-th output tensor. + * \brief Get the primitive element type of the i-th output tensor. * \param i The output index. - * \return type of i-th output. + * \return primitive element type of i-th output. */ - virtual DataType output_dtype(size_t i) const = 0; + virtual PrimType output_dtype(size_t i) const = 0; /*! * \brief Get shape of i-th output tensor. * \param i The output index. @@ -101,11 +102,11 @@ class PlaceholderOpNode : public OperationNode { public: /*! \brief The shape of the input */ ffi::Array shape; - /*! \brief The data type of the input. */ - DataType dtype; + /*! \brief The dtype of the input. */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; // override behavior. 
int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -124,7 +125,9 @@ class PlaceholderOpNode : public OperationNode { */ class PlaceholderOp : public Operation { public: - TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, DataType dtype); + TVM_DLL PlaceholderOp(std::string name, ffi::Array shape, PrimType dtype); + PlaceholderOp(std::string name, ffi::Array shape, DLDataType dtype) + : PlaceholderOp(std::move(name), std::move(shape), PrimType(dtype)) {} TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PlaceholderOp, Operation, PlaceholderOpNode); }; @@ -162,7 +165,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { ComputeOpNode() {} // override functions int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array InputTensors() const final; static void RegisterReflection() { @@ -217,7 +220,7 @@ class ScanOpNode : public OperationNode { ScanOpNode() {} // override behavior. int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -266,7 +269,7 @@ class ExternOpNode : public OperationNode { ExternOpNode() {} // override functions int num_outputs() const final; - DataType output_dtype(size_t i) const final; + PrimType output_dtype(size_t i) const final; ffi::Array output_shape(size_t i) const final; ffi::Array InputTensors() const final; @@ -299,7 +302,7 @@ class ExternOp : public Operation { * \param name_hint The name hint for the expression * \param t The type of the expression */ -TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32)); +TVM_DLL Var var(std::string name_hint, PrimType t = PrimType::Int(32)); /*! 
* \brief Create a new IterVar that represents an axis in thread. @@ -329,9 +332,14 @@ using FBatchCompute = std::function(const ffi::Array& * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(ffi::Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(ffi::Array shape, PrimType dtype = PrimType::Float(32), std::string name = "placeholder"); +inline Tensor placeholder(ffi::Array shape, DLDataType dtype, + std::string name = "placeholder") { + return placeholder(std::move(shape), PrimType(dtype), std::move(name)); +} + /*! * \brief Construct a new tensor by computing over shape, * using the computation rule: result_tensor[axis] = fcompute(axis) diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index ed07a35fb2da..760d308623f8 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -71,8 +71,8 @@ class TensorNode : public DataProducerNode { public: /*! \brief The shape of the tensor */ ffi::Array shape; - /*! \brief data type in the content of the tensor */ - DataType dtype; + /*! \brief dtype in the content of the tensor */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! \brief the source operation, can be None */ Operation op; /*! 
\brief the output index from source operation */ @@ -82,7 +82,7 @@ class TensorNode : public DataProducerNode { ffi::Array GetShape() const final { return shape; } - DataType GetDataType() const final { return dtype; } + PrimType GetDataType() const final { return dtype; } TVM_DLL PrimExpr ToPrimExpr() const final; @@ -108,7 +108,9 @@ class Tensor : public DataProducer { inline PrimExpr IndexTensor(ffi::Array indices, bool support_negative_indices) const; public: - TVM_DLL Tensor(ffi::Array shape, DataType dtype, Operation op, int value_index); + TVM_DLL Tensor(ffi::Array shape, PrimType dtype, Operation op, int value_index); + Tensor(ffi::Array shape, DLDataType dtype, Operation op, int value_index) + : Tensor(std::move(shape), PrimType(dtype), std::move(op), value_index) {} /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h index 1456787d688b..71d4c974dbb8 100644 --- a/include/tvm/tirx/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -40,11 +40,20 @@ namespace tirx { #define TVM_INDEX_DEFAULT_I64 1 #endif /*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32 */ -inline DataType DefaultIndexType() { +inline PrimType DefaultIndexPrimType() { #if TVM_INDEX_DEFAULT_I64 - return DataType::Int(64); + static const PrimType default_index_ty = PrimType::Int(64); #else - return DataType::Int(32); + static const PrimType default_index_ty = PrimType::Int(32); +#endif + return default_index_ty; +} + +inline DLDataType DefaultIndexType() { +#if TVM_INDEX_DEFAULT_I64 + return DLDataType{kDLInt, 64, 1}; +#else + return DLDataType{kDLInt, 32, 1}; #endif } @@ -67,8 +76,8 @@ class BufferNode : public ffi::Object { * \sa data_alignment The alignment of data in bytes. */ Var data; - /*! \brief data type in the content of the tensor */ - DataType dtype; + /*! \brief dtype in the content of the tensor */ + PrimType dtype{DLDataType{kDLOpaqueHandle, 0, 0}}; /*! 
\brief The type of the buffer prior to flattening * * This contains the shape as it is accessed by @@ -147,10 +156,13 @@ class BufferNode : public ffi::Object { } /*! \return preferred index type for this buffer node */ - DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : tvm::tirx::DefaultIndexType(); + DLDataType DefaultIndexType() const { + return shape.size() != 0 ? shape[0].ty()->dtype : tvm::tirx::DefaultIndexType(); } + /*! \return primitive element type for compiler-side uses. */ + PrimType ElementType() const { return dtype; } + /*! \brief Determine the offset in the buffer of the given index. * * Returns the buffer offset, in number of elements of type dtype, @@ -176,11 +188,19 @@ class Buffer : public ffi::ObjectRef { public: // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. - TVM_DLL Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, + TVM_DLL Buffer(Var data, PrimType dtype, ffi::Array shape, ffi::Array strides, PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, BufferType buffer_type, ffi::Array axis_separators = {}, Span span = Span(), ffi::Optional layout = std::nullopt, ffi::Array allocated_addr = {}); + Buffer(Var data, DLDataType dtype, ffi::Array shape, ffi::Array strides, + PrimExpr elem_offset, ffi::String name, int data_alignment, int offset_factor, + BufferType buffer_type, ffi::Array axis_separators = {}, Span span = Span(), + ffi::Optional layout = std::nullopt, ffi::Array allocated_addr = {}) + : Buffer(std::move(data), PrimType(dtype), std::move(shape), std::move(strides), + std::move(elem_offset), std::move(name), data_alignment, offset_factor, buffer_type, + std::move(axis_separators), std::move(span), std::move(layout), + std::move(allocated_addr)) {} /*! 
* \brief Return a new buffer that is equivalent with current one @@ -205,7 +225,7 @@ class Buffer : public ffi::ObjectRef { * \param offset The offset of ptr. * \param input_extent The extent of ptr. */ - TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), + TVM_DLL PrimExpr access_ptr(int access_mask, PrimType ptr_type = PrimType::Handle(), int content_lanes = 1, PrimExpr offset = IntImm::Int32(0), ffi::Optional input_extent = std::nullopt) const; /*! @@ -215,7 +235,7 @@ class Buffer : public ffi::ObjectRef { * \param predicate A vector mask of boolean values indicating which lanes of a vector are to be * loaded. The number lanes of the mask must be equal to the number of lanes in being loaded. */ - TVM_DLL PrimExpr vload(ffi::Array begin, DataType dtype, + TVM_DLL PrimExpr vload(ffi::Array begin, PrimType dtype, ffi::Optional predicate = std::nullopt) const; /*! * \brief Create a Stmt that does a vector store at begin index. @@ -267,7 +287,11 @@ class Buffer : public ffi::ObjectRef { /*! * \brief Return a new buffer with the dtype. */ - TVM_DLL Buffer with_dtype(DataType dtype) const; + TVM_DLL Buffer with_dtype(PrimType dtype) const; + Buffer with_dtype(DLDataType dtype) const { return with_dtype(PrimType(dtype)); } + + /*! \return primitive element type for compiler-side uses. */ + PrimType ElementType() const { return (*this)->ElementType(); } /*! * \brief Return a new buffer with the data. @@ -289,11 +313,20 @@ class Buffer : public ffi::ObjectRef { * \return The created buffer. * \sa Buffer for complete constructor. 
*/ -TVM_DLL Buffer decl_buffer(ffi::Array shape, DataType dtype = DataType::Float(32), +TVM_DLL Buffer decl_buffer(ffi::Array shape, + DLDataType dtype = DLDataType{kDLFloat, 32, 1}, ffi::String name = "buffer", ffi::String storage_scope = "", ffi::Optional> axis_separators = std::nullopt, Span span = Span()); +inline Buffer decl_buffer(ffi::Array shape, PrimType dtype, ffi::String name = "buffer", + ffi::String storage_scope = "", + ffi::Optional> axis_separators = std::nullopt, + Span span = Span()) { + return decl_buffer(std::move(shape), dtype->dtype, std::move(name), std::move(storage_scope), + std::move(axis_separators), std::move(span)); +} + /*! * \brief Base node for data producers. * @@ -316,10 +349,10 @@ class DataProducerNode : public PrimExprConvertibleNode { */ virtual ffi::Array GetShape() const = 0; /*! - * \brief Get the data type of the result. - * \return The data type. + * \brief Get the raw element dtype of the result. + * \return The raw dtype. */ - virtual DataType GetDataType() const = 0; + virtual PrimType GetDataType() const = 0; /*! * \brief Get the name hint of the data producer. * \return The data type. @@ -350,10 +383,18 @@ class DataProducer : public PrimExprConvertible { * \param compact If the statement has already bound to a compact buffer. 
* \param memory_scope memory scope of the buffer */ -TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, +TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DLDataType dtype, std::string name, int data_alignment, int offset_factor, bool compact, std::string memory_scope = ""); + +inline tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, PrimType dtype, + std::string name, int data_alignment, + int offset_factor, bool compact, + std::string memory_scope = "") { + return BufferWithOffsetAlignment(std::move(shape), dtype->dtype, std::move(name), data_alignment, + offset_factor, compact, std::move(memory_scope)); +} } // namespace tirx } // namespace tvm #endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/tirx/expr.h b/include/tvm/tirx/expr.h index cd51108b0d23..bf4c9004e84d 100644 --- a/include/tvm/tirx/expr.h +++ b/include/tvm/tirx/expr.h @@ -27,13 +27,13 @@ #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -96,7 +96,7 @@ class CastNode : public PrimExprNode { */ class Cast : public PrimExpr { public: - TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); + TVM_DLL Cast(PrimType value_ty, PrimExpr value, Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Cast, PrimExpr, CastNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); }; @@ -752,9 +752,9 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, Attrs attrs = Attrs(), + TVM_DLL Call(PrimType ret_ty, RelaxExpr op, ffi::Array args, Attrs attrs = Attrs(), Span span = Span()); - TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span); + TVM_DLL Call(PrimType ret_ty, RelaxExpr op, ffi::Array args, Span span); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; diff --git a/include/tvm/tirx/op.h 
b/include/tvm/tirx/op.h index 416aff73ee29..be827b9ef534 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -39,6 +39,7 @@ #include #include #include +#include namespace tvm { @@ -58,34 +59,36 @@ namespace tvm { /*! * \brief Get the type of the expression under the unified type system. * - * This function could return a more refined type than - * the runtime type provided by expr->dtype + * This function could return a more refined type than the runtime dtype + * implied by PrimExpr::ty(). * * \param expr The input parameter. * \return The result type. * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. */ TVM_DLL Type GetType(const PrimExpr& expr); /*! - * \brief Get the type corresponding to DataType - * \param dtype The data type + * \brief Get the type corresponding to a runtime DLPack dtype. + * \param dtype The runtime dtype. * \return The result type * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. */ -TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype); +TVM_DLL Type GetTypeFromRuntimeDataType(DLDataType dtype); /*! - * \brief Get the implied DataType for storing values with type during runtime. + * \brief Get the implied DLPack dtype for storing values with type during runtime. * * \param type The input type. - * \return The result runtime::DataType. + * \return The result DLPack dtype. * - * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. + * \sa tvm/ir/type.h for discussion about the relation between Type and DLPack dtype. 
*/ -TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); +TVM_DLL DLDataType GetRuntimeDLDataType(const Type& type); + +inline DLDataType GetRuntimeDataType(const Type& type) { return GetRuntimeDLDataType(type); } /*! * \brief Return the value. @@ -120,27 +123,27 @@ TVM_DLL PrimExpr break_loop(Span span = Span()); /*! * Query the maximum possible value of dtype. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the maximum possible value in this format. */ -TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr max_value(PrimType dtype, Span span = Span()); /*! * Query the minimum possible value of dtype. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the minimum possible value in this format. */ -TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr min_value(PrimType dtype, Span span = Span()); /*! * Get the value of infinity. - * \param dtype The data type. + * \param dtype The primitive type. * \param span The location of this operation in the source. * \return the infinity value in this format. */ -TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span()); +TVM_DLL PrimExpr infinity(PrimType dtype, Span span = Span()); /*! * \brief cast value to type. @@ -151,7 +154,7 @@ TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span()); * \return The result expression. * \note This function may return value if the type is the same. */ -TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span()); +TVM_DLL PrimExpr cast(PrimType t, PrimExpr value, Span span = Span()); /*! * \brief perform reinterpret cast value to type. * @@ -161,7 +164,7 @@ TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span()); * \return The result expression. 
* \note This function may return value if the type is the same. */ -TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span()); +TVM_DLL PrimExpr reinterpret(PrimType t, PrimExpr value, Span span = Span()); /*! * \brief add operator * @@ -691,13 +694,13 @@ TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span()); /*! * \brief Construct a large uint constant by its low 32 bits and high 32bits. - * \param dtype The final data type. + * \param value_ty The final primitive type. * \param low The lower 32 bits. * \param high The higher 32 bits. * \param span The location of this operation in the source. * \return The constructed expression. */ -TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span()); +TVM_DLL PrimExpr LargeUIntImm(PrimType value_ty, int64_t low, int64_t high, Span span = Span()); /*! * \brief Execute a multiplication between two Q-numbers x and y @@ -731,29 +734,35 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s */ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); -inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) { - TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError) +inline void CheckMathUnaryOpInputDType(const char* op_name, const PrimType& dtype) { + TVM_FFI_CHECK(dtype.code() == DLDataTypeCode::kDLFloat || + dtype.MatchesElementType(DLDataTypeCode::kDLBfloat, 16), + TypeError) << "tirx." << op_name << " only supports floating-point inputs, but got " << dtype; } // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op op = Op::Get("tirx." 
#OpName); \ - CheckInputDType(#OpName, x.dtype()); \ - if (x.dtype().is_bfloat16()) { \ - DataType bf16_dtype = x.dtype(); \ - DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ - PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \ - PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, {}, span); \ - return tirx::Cast(bf16_dtype, {result_fp32}, span); \ - } else { \ - return tirx::Call(x.dtype(), op, {x}, {}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op op = Op::Get("tirx." #OpName); \ + PrimType x_ty = x.ty(); \ + CheckInputDType(#OpName, x_ty); \ + if (x_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { \ + PrimType bf16_ty = x_ty; \ + PrimType f32_ty = \ + x_ty.IsScalableVector() \ + ? PrimType::ScalableVector(DLDataTypeCode::kDLFloat, 32, x_ty.VScaleFactor()) \ + : PrimType::Float(32, x_ty.lanes()); \ + PrimExpr x_fp32 = tirx::Cast(f32_ty, x, span); \ + PrimExpr result_fp32 = tirx::Call(f32_ty, op, {x_fp32}, {}, span); \ + return tirx::Cast(bf16_ty, result_fp32, span); \ + } else { \ + return tirx::Call(x_ty, op, {x}, {}, span); \ + } \ } #define TVM_DECLARE_INTRIN_UNARY(OpName) \ - TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, DataType) {}) + TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, [](const char*, const PrimType&) {}) #define TVM_DECLARE_FLOAT_INTRIN_UNARY(OpName) \ TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckMathUnaryOpInputDType) @@ -787,7 +796,7 @@ TVM_DECLARE_INTRIN_UNARY(clz); #define TVM_DECLARE_INTRIN_BINARY(OpName) \ inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ static const Op op = Op::Get("tirx." #OpName); \ - return tirx::Call(x.dtype(), op, {x, y}, {}, span); \ + return tirx::Call(x.ty(), op, {x, y}, {}, span); \ } TVM_DECLARE_INTRIN_BINARY(atan2); @@ -804,7 +813,7 @@ namespace tirx { * \param element_type The corresponding element type. 
* \return The check results */ -inline bool IsPointerType(const Type& type, const DataType& element_type) { +inline bool IsPointerType(const Type& type, DLDataType element_type) { if (!type.defined()) return false; if (const auto* ptr_type = type.as()) { if (const auto* prim_type = ptr_type->element_type.as()) { @@ -832,7 +841,7 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { template ::value && std::is_trivial::value>::type> -inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span()); +inline PrimExpr MakeConst(PrimType dtype, ValueType value, Span span = Span()); /*! * \brief Make a constant handle value. * \param value The integer payload to reinterpret as a handle. @@ -970,9 +979,12 @@ inline bool is_no_op(const tirx::Stmt& stmt) { } template -inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) { - if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast(value), span); - if (dtype.is_uint()) { +inline PrimExpr MakeConstScalar(PrimType dtype, ValueType value, Span span = Span()) { + DLDataTypeCode code = dtype.code(); + if (code == DLDataTypeCode::kDLInt || code == DLDataTypeCode::kDLBool) { + return IntImm(dtype, static_cast(value), span); + } + if (code == DLDataTypeCode::kDLUInt) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); if (value < static_cast(0)) { @@ -986,8 +998,13 @@ inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Spa return LargeUIntImm(dtype, static_cast(low), static_cast(high), span); } } - if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLFloat, DLDataTypeCode::kDLFloat8_e3m4, + DLDataTypeCode::kDLFloat8_e4m3, DLDataTypeCode::kDLFloat8_e4m3b11fnuz, + DLDataTypeCode::kDLFloat8_e4m3fn, DLDataTypeCode::kDLFloat8_e4m3fnuz, + DLDataTypeCode::kDLFloat8_e5m2, DLDataTypeCode::kDLFloat8_e5m2fnuz, 
+ DLDataTypeCode::kDLFloat8_e8m0fnu, DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn, DLDataTypeCode::kDLFloat4_e2m1fn) || + dtype.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { return FloatImm(dtype, static_cast(value), span); } TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype; @@ -995,27 +1012,26 @@ inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Spa } template <> -inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) { +inline PrimExpr MakeConstScalar(PrimType dtype, bool value, Span span) { return MakeConstScalar(dtype, static_cast(value), span); } template -inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) { - if (dtype.is_scalar()) { +inline PrimExpr MakeConst(PrimType dtype, ValueType value, Span span) { + if (!dtype.IsScalableVector() && !dtype.IsFixedLengthVector()) { return MakeConstScalar(dtype, value, span); - } else { - if (dtype.is_fixed_length_vector()) { - return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span); - } else { - PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), - dtype.vscale_factor()); - return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span); - } } + PrimType elem_ty = dtype.WithLanes(1); + if (dtype.IsFixedLengthVector()) { + return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), dtype.lanes(), span); + } + PrimExpr lanes = + tirx::Mul(tirx::Call(PrimType::Int(32), tirx::builtin::vscale(), {}), dtype.VScaleFactor()); + return tirx::Broadcast(MakeConstScalar(elem_ty, value, span), lanes, span); } inline PrimExpr ConstHandle(int64_t value, Span span) { - return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span)); + return reinterpret(PrimType::Handle(), IntImm(PrimType::UInt(64), value, span)); } } // namespace tirx @@ -1027,17 +1043,13 @@ inline PrimExpr ConstHandle(int64_t value, Span span) 
{ return a; \ } -#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ - inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tirx::MakeConst(b.dtype(), a), b); \ - } \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::MakeConst(a.dtype(), b)); \ - } \ - inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, FloatImm(DataType::Float(64), b)); \ +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.ty(), a), b); } \ + inline PrimExpr Name(const PrimExpr& a, int b) { return Name(a, tirx::MakeConst(a.ty(), b)); } \ + inline PrimExpr Name(const PrimExpr& a, double b) { \ + return Name(a, FloatImm(PrimType::Float(64), b)); \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \ @@ -1048,13 +1060,13 @@ inline PrimExpr ConstHandle(int64_t value, Span span) { return Name(PrimExpr(a), b, span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::MakeConst(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.ty(), a), b, span); \ } \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::MakeConst(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.ty(), b), span); \ } \ inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \ - return Name(a, FloatImm(DataType::Float(64), b), span); \ + return Name(a, FloatImm(PrimType::Float(64), b), span); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ @@ -1069,18 +1081,16 @@ inline PrimExpr ConstHandle(int64_t value, Span span) { return 
Name(PrimExpr(a), b, span); \ } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::MakeConst(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { return Name(a, tirx::MakeConst(a.ty(), b)); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.ty(), a), b); } #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::MakeConst(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.ty(), b), span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::MakeConst(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.ty(), a), b, span); \ } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); diff --git a/include/tvm/tirx/script/builder/ir.h b/include/tvm/tirx/script/builder/ir.h index ad18d7ac4001..684653134a55 100644 --- a/include/tvm/tirx/script/builder/ir.h +++ b/include/tvm/tirx/script/builder/ir.h @@ -57,7 +57,7 @@ using tvm::tirx::Var; * \param axis_separators The separators between input axes when generating flattened output axes. * \return The declared buffer. */ -Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer_name, +Buffer BufferDecl(ffi::Array shape, PrimType dtype, ffi::String buffer_name, ffi::Optional data, ffi::Optional> strides, ffi::Optional elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, @@ -122,7 +122,7 @@ Type FuncRet(Type ret_type); * \return The matched buffer. 
*/ Buffer MatchBuffer(ffi::ObjectRef param, ffi::Array shape, - DataType dtype = DataType::Float(32), ffi::Optional data = std::nullopt, + PrimType dtype = PrimType::Float(32), ffi::Optional data = std::nullopt, ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "global", int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", @@ -197,7 +197,7 @@ void BlockAttrs(ffi::Map attrs); * T.prim_func(tirx=True). */ ffi::Variant SBlockAllocBuffer( - ffi::Array shape, DataType dtype = DataType::Float(32), + ffi::Array shape, PrimType dtype = PrimType::Float(32), ffi::Optional data = std::nullopt, ffi::Array strides = {}, PrimExpr elem_offset = PrimExpr(), ffi::String storage_scope = "", int align = -1, int offset_factor = 0, ffi::String buffer_type = "default", @@ -213,7 +213,7 @@ namespace axis { * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Spatial(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The reduced block axis defining function. @@ -222,7 +222,7 @@ Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Reduce(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The scanning block axis defining function. @@ -231,7 +231,7 @@ Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Scan(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The opaque block axis defining function. 
@@ -240,7 +240,7 @@ Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \param dtype The data type of the iteration variable. * \return The iteration variable. */ -Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); +Var Opaque(Range dom, PrimExpr binding, PrimType dtype = PrimType::Int(32)); /*! * \brief The block axis remapping function. @@ -250,7 +250,7 @@ Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32)); * \return The iteration variables. */ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, - DataType dtype = DataType::Int(32)); + PrimType dtype = PrimType::Int(32)); } // namespace axis @@ -412,7 +412,7 @@ ElseFrame Else(); * \param layout The layout of the buffer. * \return The declaration frame. */ -DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, +DeclBufferFrame DeclBuffer(ffi::Array shape, PrimType dtype, ffi::String buffer_name, ffi::Optional data, ffi::Optional> strides, ffi::Optional elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, @@ -428,7 +428,7 @@ DeclBufferFrame DeclBuffer(ffi::Array shape, DataType dtype, ffi::Stri * \param annotations Optional annotations for the allocation. * \return The allocated buffer. */ -Buffer AllocBuffer(ffi::Array shape, DataType dtype = DataType::Float(32), +Buffer AllocBuffer(ffi::Array shape, PrimType dtype = PrimType::Float(32), ffi::String storage_scope = "global", ffi::Optional> annotations = std::nullopt); @@ -465,7 +465,7 @@ ComposeOpFrame ComposeOp(ffi::Map workspace, * \param dtype The data type of the variable. * \return The result variable which gets bound to the thread env. */ -Var EnvThread(ffi::String thread_tag, DataType dtype = DataType::Int(32)); +Var EnvThread(ffi::String thread_tag, PrimType dtype = PrimType::Int(32)); /*! * \brief Store data in a buffer. 
@@ -494,21 +494,21 @@ void Evaluate(PrimExpr value); * \param is_size_var Whether the pointer is a size var. * * \param is_unknown_type Used to distinguish between - * `PrimType(DataType::Handle())` and - * `PointerType(PrimType(DataType::Void()))`. If true, resolve dtype + * `PrimType::Handle()` and `PointerType(PrimType(DLDataType{kDLOpaqueHandle, 0, 0}))`. + * If true, resolve dtype * of `Void()` as `PrimType`, and if false resolve dtype of `Void()` * as a `PointerType`. * * \return The pointer. */ -inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), - ffi::String storage_scope = "global", bool is_size_var = false, - bool is_unknown_type = false) { +inline Var Handle(PrimType dtype = PrimType(DLDataType{kDLOpaqueHandle, 0, 0}), + ffi::String storage_scope = "global", bool is_size_var = false, + bool is_unknown_type = false) { Type type_annotation{nullptr}; if (is_unknown_type && storage_scope == "global") { - type_annotation = PrimType(runtime::DataType::Handle()); + type_annotation = PrimType::Handle(); } else { - type_annotation = PointerType(PrimType(dtype), storage_scope); + type_annotation = PointerType(dtype, storage_scope); } return is_size_var ? tvm::tirx::SizeVar("", type_annotation) : tvm::tirx::Var("", type_annotation); @@ -519,67 +519,67 @@ inline Var TensorMap() { return tvm::tirx::Var("", PointerType(TensorMapType())) #define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ bool is_size_var = false) { \ - DataType dtype = DType; \ + PrimType dtype(DType); \ return expr.defined() \ ? tvm::cast(dtype, expr.value()) \ : (is_size_var ? 
tvm::tirx::SizeVar("", dtype) : tvm::tirx::Var("", dtype)); \ } -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2)) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64)); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); - -#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2)); \ - 
TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, FDType(32)); \ - TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, FDType(64)); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, DataType::Float8E3M4); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, DataType::Float8E4M3); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, DataType::Float8E4M3B11FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, DataType::Float8E4M3FN); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, DataType::Float8E4M3FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, DataType::Float8E5M2); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, DataType::Float8E5M2FNUZ); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, DataType::Float8E8M0FNU); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, DataType::Float6E2M3FN); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, DataType::Float6E3M2FN); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, DataType::Float4E2M1FN); - -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); -TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, Code) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##8, (DLDataType{Code, 8, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##16, (DLDataType{Code, 16, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##32, (DLDataType{Code, 32, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##64, (DLDataType{Code, 64, 1})); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, kDLBfloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, kDLFloat); 
+TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, kDLUInt); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, kDLInt); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, Code, Size) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, (DLDataType{Code, Size, 2})) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, (DLDataType{Code, Size, 4})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, (DLDataType{Code, Size, 8})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, (DLDataType{Code, Size, 16})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, (DLDataType{Code, Size, 32})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, (DLDataType{Code, Size, 64})); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, Code) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, Code, 8); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, Code, 16); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, Code, 32); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, Code, 64); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, kDLBfloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, kDLFloat); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, kDLUInt); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, kDLInt); + +#define TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, Code, Bits) \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType, (DLDataType{Code, Bits, 1})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, (DLDataType{Code, Bits, 2})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, (DLDataType{Code, Bits, 4})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, (DLDataType{Code, Bits, 8})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, (DLDataType{Code, Bits, 16})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x32, (DLDataType{Code, Bits, 32})); \ + TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(DType##x64, (DLDataType{Code, Bits, 64})); + 
+TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E3M4, kDLFloat8_e3m4, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3, kDLFloat8_e4m3, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3B11FNUZ, kDLFloat8_e4m3b11fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FN, kDLFloat8_e4m3fn, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E4M3FNUZ, kDLFloat8_e4m3fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2, kDLFloat8_e5m2, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E5M2FNUZ, kDLFloat8_e5m2fnuz, 8); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float8E8M0FNU, kDLFloat8_e8m0fnu, 8); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E2M3FN, kDLFloat6_e2m3fn, 6); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float6E3M2FN, kDLFloat6_e3m2fn, 6); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(Float4E2M1FN, kDLFloat4_e2m1fn, 4); + +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Boolean, (DLDataType{kDLBool, 8, 1})); +TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST(Void, (DLDataType{kDLOpaqueHandle, 0, 0})); #undef TVM_TIRX_IR_BUILDER_DEF_DTYPE_CAST diff --git a/include/tvm/tirx/stmt.h b/include/tvm/tirx/stmt.h index 1ed4d5acac54..7eb004f8cf25 100644 --- a/include/tvm/tirx/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -1282,7 +1282,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \param span The location of this object in the source code. * \return Expr a expression with dtype. */ -TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); +TVM_DLL PrimExpr TypeAnnotation(PrimType dtype, Span span = Span()); // overload printing of for type. 
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); diff --git a/include/tvm/tirx/var.h b/include/tvm/tirx/var.h index 8c536ef0d668..3a4746a3f6a2 100644 --- a/include/tvm/tirx/var.h +++ b/include/tvm/tirx/var.h @@ -24,9 +24,9 @@ #ifndef TVM_TIR_VAR_H_ #define TVM_TIR_VAR_H_ +#include #include #include -#include #include #include @@ -57,7 +57,7 @@ class VarNode : public PrimExprNode { * * It is an optional field that provides a refined type of the variable than dtype. * - * \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type. + * \sa tvm/ir/type.h for discussion of relations between DLPack dtype and Type. */ Type type_annotation; @@ -84,7 +84,7 @@ class Var : public PrimExpr { * \param dtype data type * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(ffi::String name_hint = "v", DataType dtype = DataType::Int(32), + TVM_DLL explicit Var(ffi::String name_hint = "v", PrimType dtype = PrimType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. @@ -110,7 +110,7 @@ class Var : public PrimExpr { * \param dtype The specified dtype * \return The new variable */ - TVM_DLL Var copy_with_dtype(DataType dtype) const; + TVM_DLL Var copy_with_dtype(PrimType dtype) const; /*! * \brief Get pointer to the internal value. @@ -150,7 +150,7 @@ class SizeVar : public Var { * \param t data type * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(ffi::String name_hint = "s", DataType t = DataType::Int(32), + TVM_DLL explicit SizeVar(ffi::String name_hint = "s", PrimType t = PrimType::Int(32), Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. 
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index b0c6ac8f6722..26bf7c100ca5 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -252,7 +252,8 @@ TOPI_DEFINE_BCAST_OP(divide, { return div(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(floor_divide, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return floordiv(a, b); } else { return floor(div(a, b)); @@ -287,7 +288,8 @@ TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(trunc_divide, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return truncdiv(a, b); } else { return trunc(div(a, b)); @@ -319,7 +321,8 @@ TOPI_DEFINE_BCAST_OP(mod, { return truncmod(a, b); }); * \return The result. */ TOPI_DEFINE_BCAST_OP(floor_mod, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return floormod(a, b); } else { return a - floor_divide(a, b) * b; @@ -338,7 +341,8 @@ TOPI_DEFINE_BCAST_OP(floor_mod, { * \return The result. */ TOPI_DEFINE_BCAST_OP(trunc_mod, { - if (a.dtype().is_int() || a.dtype().is_uint()) { + PrimType a_ty = a.ty(); + if (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt) { return truncmod(a, b); } else { return a - trunc_divide(a, b) * b; diff --git a/include/tvm/topi/contrib/cublas.h b/include/tvm/topi/contrib/cublas.h index 3590b7a54458..18ad4320f489 100644 --- a/include/tvm/topi/contrib/cublas.h +++ b/include/tvm/topi/contrib/cublas.h @@ -48,7 +48,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, b auto m = transb ? 
rhs->shape[0] : rhs->shape[1]; return make_extern( - {{n, m}}, {lhs->dtype}, {lhs, rhs}, + {{n, m}}, {lhs->GetDataType()}, {lhs, rhs}, [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); @@ -73,7 +73,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool tra auto m = transb ? rhs->shape[1] : rhs->shape[2]; return make_extern( - {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + {{b, n, m}}, {lhs->GetDataType()}, {lhs, rhs}, [&](ffi::Array ins, ffi::Array outs) { return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index c9dce9eb7489..e5984fd1d787 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -42,10 +42,12 @@ struct BroadcastHelper { std::deque vars2; }; -static inline DataType CommonType(DataType type1, DataType type2) { - TVM_FFI_ICHECK(type1.is_scalar() && type2.is_scalar()); +static inline PrimType CommonType(const PrimType& type1, const PrimType& type2) { + TVM_FFI_ICHECK(!type1.IsScalableVector() && !type2.IsScalableVector()); + TVM_FFI_ICHECK_EQ(type1.lanes(), 1); + TVM_FFI_ICHECK_EQ(type2.lanes(), 1); TVM_FFI_ICHECK(type1.code() == type2.code()); - return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); + return type1.bits() < type2.bits() ? type1.WithBits(type2.bits()) : type1; } inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shape1, @@ -56,15 +58,15 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap tvm::PrimExpr one(1); int i; - auto cast_if_needed = [](DataType to_type, PrimExpr expr) { - return to_type != expr.dtype() ? 
cast(to_type, expr) : expr; + auto cast_if_needed = [](PrimType to_type, PrimExpr expr) { + return to_type->dtype == expr.ty()->dtype ? expr : cast(to_type, expr); }; for (i = 1; i <= std::min(s1_size, s2_size); ++i) { // TODO(@icemelon9): Need to revisit this part const IntImmNode* static_size1 = shape1[s1_size - i].as(); const IntImmNode* static_size2 = shape2[s2_size - i].as(); - DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype()); + PrimType common_type = CommonType(shape1[s1_size - i].ty(), shape2[s2_size - i].ty()); bh.all_vars.push_front(tvm::tirx::Var("dim", common_type)); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { @@ -104,7 +106,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap auto& shape = (s1_size > s2_size) ? shape1 : shape2; auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2; for (; i <= max_size; ++i) { - bh.all_vars.push_front(tvm::tirx::Var("v", shape[max_size - 1].dtype())); + bh.all_vars.push_front(tvm::tirx::Var("v", shape[max_size - 1].ty())); bh.common_shape.push_front(shape[max_size - i]); vars.push_front(bh.all_vars[0]); } @@ -130,7 +132,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( // Only inject 0 here if we have not yet reached the dimension of I // (i.e. this must be a 1) if (!found && (ovars.size() - i) <= expected_dims) { - ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0)); + ivars.push_back(tvm::IntImm(ovars[i].ty(), 0)); } } TVM_FFI_ICHECK(expected_dims == ivars.size()); diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 161d5291c38e..b0ce2d713bee 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -28,6 +28,7 @@ #include #include +#include #include namespace tvm { @@ -61,7 +62,7 @@ using FExtern = std::function, ffi::Array)>; * element of out_types. 
*/ inline ffi::Array make_extern(const ffi::Array>& out_shapes, - const std::vector& out_types, + const std::vector& out_types, const ffi::Array& inputs, FExtern fextern, std::string name, std::string tag, ::tvm::ffi::Map attrs) { @@ -100,10 +101,10 @@ inline ffi::Array make_extern(const ffi::Array>& ou inline PrimExpr pack_buffer(Buffer buf) { TVM_FFI_ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; auto shape = - tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); + tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), + strides = tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->strides); } else { strides = 0; @@ -112,9 +113,9 @@ inline PrimExpr pack_buffer(Buffer buf) { shape, strides, IntImm::Int32(static_cast(buf->shape.size())), - MakeConst(buf->dtype, 0), + MakeConst(PrimType(buf->dtype), 0), buf->elem_offset}; - return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); + return tvm::tirx::Call(PrimType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); } /*! 
@@ -127,7 +128,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(ffi::Array args) { - return tvm::tirx::Call(DataType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); + return tvm::tirx::Call(PrimType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); } } // namespace detail diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index 19ee79a2086f..95ab3a38cbc0 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -91,7 +91,7 @@ inline ffi::Array StridedSliceCanonicalizeBegin(const ffi::Array& begin, const std::vector& strides, const ffi::Array& axes, - DataType dtype, + PrimType dtype, std::string slice_mode = "end") { ffi::Array begin_expr; for (size_t i = 0; i < axes.size(); ++i) { @@ -140,9 +140,9 @@ inline ffi::Array StridedSliceOutputShape( static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); TVM_FFI_ICHECK(strides[i] < 0 ? 
(end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; - out_shape.Set(ax, cast(out_shape[i].dtype(), PrimExpr(slice_size))); + out_shape.Set(ax, cast(out_shape[i].ty(), PrimExpr(slice_size))); } else { - out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i]->dtype)); + out_shape.Set(ax, tvm::tirx::Var("dim", out_shape[i].ty())); } } diff --git a/include/tvm/topi/detail/tensor_utils.h b/include/tvm/topi/detail/tensor_utils.h index d67ad6359434..82649cd0b387 100644 --- a/include/tvm/topi/detail/tensor_utils.h +++ b/include/tvm/topi/detail/tensor_utils.h @@ -70,10 +70,10 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const ffi::Arraydtype, -9.0), minimum(MakeConst(in->dtype, 9.0), in)); + PrimType input_type = in->GetDataType(); + auto x = maximum(MakeConst(input_type, -9.0), minimum(MakeConst(input_type, 9.0), in)); // The monomial coefficients of the numerator polynomial (odd). - auto alpha_1 = MakeConst(in->dtype, 4.89352455891786e-03); - auto alpha_3 = MakeConst(in->dtype, 6.37261928875436e-04); - auto alpha_5 = MakeConst(in->dtype, 1.48572235717979e-05); - auto alpha_7 = MakeConst(in->dtype, 5.12229709037114e-08); - auto alpha_9 = MakeConst(in->dtype, -8.60467152213735e-11); - auto alpha_11 = MakeConst(in->dtype, 2.00018790482477e-13); - auto alpha_13 = MakeConst(in->dtype, -2.76076847742355e-16); + auto alpha_1 = MakeConst(input_type, 4.89352455891786e-03); + auto alpha_3 = MakeConst(input_type, 6.37261928875436e-04); + auto alpha_5 = MakeConst(input_type, 1.48572235717979e-05); + auto alpha_7 = MakeConst(input_type, 5.12229709037114e-08); + auto alpha_9 = MakeConst(input_type, -8.60467152213735e-11); + auto alpha_11 = MakeConst(input_type, 2.00018790482477e-13); + auto alpha_13 = MakeConst(input_type, -2.76076847742355e-16); // The monomial coefficients of the denominator polynomial (even). 
- auto beta_0 = MakeConst(in->dtype, 4.89352518554385e-03); - auto beta_2 = MakeConst(in->dtype, 2.26843463243900e-03); - auto beta_4 = MakeConst(in->dtype, 1.18534705686654e-04); - auto beta_6 = MakeConst(in->dtype, 1.19825839466702e-06); + auto beta_0 = MakeConst(input_type, 4.89352518554385e-03); + auto beta_2 = MakeConst(input_type, 2.26843463243900e-03); + auto beta_4 = MakeConst(input_type, 1.18534705686654e-04); + auto beta_6 = MakeConst(input_type, 1.19825839466702e-06); return compute( x->shape, @@ -130,7 +131,7 @@ inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string ta */ inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + if (x->GetDataType().MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); } else { @@ -209,9 +210,10 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr zero = MakeConst(x->dtype, 0); - PrimExpr one = MakeConst(x->dtype, 1); - PrimExpr minus_one = MakeConst(x->dtype, -1); + PrimType x_type(x->GetDataType()); + PrimExpr zero = MakeConst(x_type, 0); + PrimExpr one = MakeConst(x_type, 1); + PrimExpr minus_one = MakeConst(x_type, -1); auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero); auto s2 = tvm::tirx::Select((x(i) > zero), one, s1); return s2; @@ -232,7 +234,7 @@ inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string ta return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr one = MakeConst(x->dtype, 1); + PrimExpr one = MakeConst(x->GetDataType(), 1); return one / tvm::sqrt(x(i)); }, name, tag); @@ -255,8 +257,9 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max return compute( x->shape, [&](const ffi::Array& i) { - auto min_val = tvm::cast(x->dtype, a_min); - auto 
max_val = tvm::cast(x->dtype, a_max); + PrimType x_type(x->GetDataType()); + auto min_val = tvm::cast(x_type, a_min); + auto max_val = tvm::cast(x_type, a_max); return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) }, name, tag); @@ -274,16 +277,24 @@ inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max * * \return A Tensor whose op member is the cast operation */ -inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", +inline Tensor cast(const Tensor& x, PrimType type, std::string name, std::string tag); + +inline Tensor cast(const Tensor& x, DLDataType type, std::string name = "T_cast", + std::string tag = kElementWise) { + return cast(x, PrimType(type), std::move(name), std::move(tag)); +} + +inline Tensor cast(const Tensor& x, PrimType type, std::string name = "T_cast", std::string tag = kElementWise) { return compute( x->shape, [&](const ffi::Array& i) -> PrimExpr { auto expr = x(i); - if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { - if (expr.dtype().lanes() == type.lanes()) { + PrimType expr_ty = expr.ty(); + if (expr_ty.MatchesElementType(type.code(), type.bits())) { + if (expr_ty.lanes() == type.lanes()) { return expr; - } else if (expr.dtype().lanes() == 1 && type.is_vector()) { + } else if (expr_ty.lanes() == 1 && type.IsFixedLengthVector()) { return tvm::tirx::Broadcast(expr, type.lanes()); } } @@ -303,7 +314,14 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", * * \return A Tensor whose op member is the reinterpret operation */ -inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", +inline Tensor reinterpret(const Tensor& x, PrimType type, std::string name, std::string tag); + +inline Tensor reinterpret(const Tensor& x, DLDataType type, std::string name = "tensor", + std::string tag = kElementWise) { + return reinterpret(x, PrimType(type), std::move(name), std::move(tag)); +} + +inline Tensor 
reinterpret(const Tensor& x, PrimType type, std::string name = "tensor", std::string tag = kElementWise) { return compute( x->shape, [&](const ffi::Array& i) { return reinterpret(type, x(i)); }, name, tag); @@ -344,7 +362,15 @@ inline Tensor elemwise_sum(const ffi::Array& xs, std::string name = "T_e * * \return A Tensor whose op member is the full operation */ -inline Tensor full(const ffi::Array& shape, DataType dtype, const PrimExpr fill_value, +inline Tensor full(const ffi::Array& shape, PrimType dtype, const PrimExpr fill_value, + std::string name, std::string tag); + +inline Tensor full(const ffi::Array& shape, DLDataType dtype, const PrimExpr fill_value, + std::string name = "T_full", std::string tag = kElementWise) { + return full(shape, PrimType(dtype), fill_value, std::move(name), std::move(tag)); +} + +inline Tensor full(const ffi::Array& shape, PrimType dtype, const PrimExpr fill_value, std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { @@ -366,7 +392,7 @@ inline Tensor full(const ffi::Array& shape, DataType dtype, const Prim */ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, std::string name = "T_full_like", std::string tag = kElementWise) { - PrimExpr ev = cast(x->dtype, fill_value); + PrimExpr ev = cast(x->GetDataType(), fill_value); return compute(x->shape, [&](const ffi::Array& i) { return ev; }, name, tag); } @@ -392,19 +418,17 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { - auto x_hi = FloatImm(DataType::Float(32), 88.3762626647950f); - auto x_lo = FloatImm(DataType::Float(32), -88.3762626647949f); - auto log2e = FloatImm(DataType::Float(32), 1.44269504088896341f); - auto ln2 = FloatImm(DataType::Float(32), 0.6931471805599453f); - PrimExpr p[6] = {FloatImm(DataType::Float(32), 1.9875691500E-4f), 
- FloatImm(DataType::Float(32), 1.3981999507E-3f), - FloatImm(DataType::Float(32), 8.3334519073E-3f), - FloatImm(DataType::Float(32), 4.1665795894E-2f), - FloatImm(DataType::Float(32), 1.6666665459E-1f), - FloatImm(DataType::Float(32), 5.0000001201E-1f)}; - auto one = FloatImm(DataType::Float(32), 1.0f); - auto one_half = FloatImm(DataType::Float(32), 0.5f); - auto b = FloatImm(DataType::Float(32), 127.0f); + PrimType f32_ty = PrimType::Float(32); + auto x_hi = FloatImm(f32_ty, 88.3762626647950f); + auto x_lo = FloatImm(f32_ty, -88.3762626647949f); + auto log2e = FloatImm(f32_ty, 1.44269504088896341f); + auto ln2 = FloatImm(f32_ty, 0.6931471805599453f); + PrimExpr p[6] = {FloatImm(f32_ty, 1.9875691500E-4f), FloatImm(f32_ty, 1.3981999507E-3f), + FloatImm(f32_ty, 8.3334519073E-3f), FloatImm(f32_ty, 4.1665795894E-2f), + FloatImm(f32_ty, 1.6666665459E-1f), FloatImm(f32_ty, 5.0000001201E-1f)}; + auto one = FloatImm(f32_ty, 1.0f); + auto one_half = FloatImm(f32_ty, 0.5f); + auto b = FloatImm(f32_ty, 127.0f); return compute( _x->shape, @@ -419,7 +443,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one; // Return 2^m * exp(r). 
auto ef = - tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23); + tvm::reinterpret(PrimType::Float(32), ::tvm::cast(PrimType::Int(32), n + b) << 23); return ::tvm::max(ef * y, _x(i)); // NOLINT(*) }, name, tag); @@ -437,7 +461,7 @@ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string t */ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + if (x->GetDataType().MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { auto ret = fast_exp_float32(x, name, tag); return ret; } else { @@ -474,10 +498,11 @@ inline Tensor fast_erf_float16(const Tensor& data, std::string name, std::string */ inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf", std::string tag = kElementWise) { - if (x->dtype == DataType::Float(32)) { + PrimType x_type(x->GetDataType()); + if (x_type.MatchesElementType(DLDataTypeCode::kDLFloat, 32)) { auto ret = fast_erf_float32(x, name, tag); return ret; - } else if (x->dtype == DataType::Float(16)) { + } else if (x_type.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { auto ret = fast_erf_float16(x, name, tag); return ret; } else { diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 0a448620dae3..b864bfe53ea3 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -57,7 +57,7 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< return tvm::te::compute( t->shape, [&](const tvm::ffi::Array& i) { - auto threshold_const = tvm::tirx::MakeConst(t->dtype, threshold); + auto threshold_const = tvm::tirx::MakeConst(tvm::PrimType(t->dtype), threshold); return tvm::max(t(i), threshold_const); }, name, tag); @@ -80,7 +80,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, t->shape, [&](const tvm::ffi::Array& i) { auto value = t(i); - auto calpha = tvm::tirx::MakeConst(value.dtype(), alpha); + auto calpha = 
tvm::tirx::MakeConst(value.ty(), alpha); return tvm::tirx::Select(value > 0, value, value * calpha); }, name, tag); @@ -171,10 +171,10 @@ inline tvm::te::Tensor pad( tvm::ffi::Array pad_after_int32; for (const auto& ele : pad_before) { - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } for (const auto& ele : pad_after) { - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } tvm::ffi::Array output_shape; @@ -194,7 +194,7 @@ inline tvm::te::Tensor pad( } if (!pad_value.defined()) { - pad_value = tvm::tirx::MakeConst(t->dtype, 0); + pad_value = tvm::tirx::MakeConst(tvm::PrimType(t->dtype), 0); } auto l = [&](tvm::ffi::Array ovars) { @@ -495,19 +495,19 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, tvm::ffi::Array pad_after_int32; // pad size for batch dimension is 0 - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), 0)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), 0)); // insert pad sizes given for spatial dimensions for (const auto& ele : pad_before) { - pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_before_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } for (const auto& ele : pad_after) { - pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); + pad_after_int32.push_back(tvm::cast(tvm::PrimType::Int(32), ele)); } // pad the input with paddings provided if (!pad_value.defined()) { - pad_value = tvm::tirx::MakeConst(data->dtype, 0); + pad_value = tvm::tirx::MakeConst(tvm::PrimType(data->dtype), 0); } padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value); @@ -629,9 +629,9 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, 
// Crop the start and end of dimensions of out ffi::Array> begin_idx, end_idx; ffi::Array strides; - DataType index_dtype = DataType::Int(64); + PrimType index_ty = PrimType::Int(64); for (size_t i = 0; i < r_p_shape.size(); ++i) { - strides.push_back(IntImm(index_dtype, 1)); + strides.push_back(IntImm(index_ty, 1)); if (i > 0 && i <= num_block_dims) { // prepare begin and end index for spatial dimensions int64_t begin_i = GetConstInt(crop_begin_list[i - 1]); @@ -640,12 +640,12 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i)) << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than" << " output size" << out_i << " vs " << (begin_i + end_i); - begin_idx.push_back(IntImm(index_dtype, begin_i)); - end_idx.push_back(IntImm(index_dtype, out_i - end_i)); + begin_idx.push_back(IntImm(index_ty, begin_i)); + end_idx.push_back(IntImm(index_ty, out_i - end_i)); } else { // ignore the batch and remaining dimension - begin_idx.push_back(IntImm(index_dtype, 0)); - end_idx.push_back(IntImm(index_dtype, GetConstInt(r_p_shape[i]))); + begin_idx.push_back(IntImm(index_ty, 0)); + end_idx.push_back(IntImm(index_ty, GetConstInt(r_p_shape[i]))); } } @@ -677,7 +677,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, -predictions(c) * weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); if (reduction == "mean") { @@ -686,7 +686,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); return topi::divide(T, W); 
@@ -705,7 +705,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T pred_indices.push_back(target_indices[i]); // indices for multidimensional loss } return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); TVM_FFI_ICHECK(T->shape.size() != 0); @@ -715,7 +715,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::MakeConst(predictions->dtype, 0)); + tvm::tirx::MakeConst(tvm::PrimType(predictions->dtype), 0)); }, name, tag); return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 5faed879c005..56a6f3aaa815 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -71,14 +71,14 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : static_cast(indices[i])); } - PrimExpr packed = IntImm(DataType::UInt(32), 0); + PrimExpr packed = IntImm(PrimType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { ffi::Array idx; for (size_t i = 0; i < n; ++i) { idx.push_back(i == static_cast(axis) ? 
start_idx[i] + static_cast(j) : start_idx[i]); } - auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); + auto sign = tvm::cast(PrimType::UInt(32), data(idx) >= 0); packed = (packed | sign); if (j == 31) { return packed; @@ -101,8 +101,8 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; - TVM_FFI_ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; - TVM_FFI_ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight"; + TVM_FFI_ICHECK_EQ(data->dtype, PrimType::UInt(32)) << "binary_dense requires uint32 data"; + TVM_FFI_ICHECK_EQ(weight->dtype, PrimType::UInt(32)) << "binary_dense requires uint32 weight"; auto batch = data->shape[0]; auto in_dim = data->shape[1]; diff --git a/include/tvm/topi/nn/dense.h b/include/tvm/topi/nn/dense.h index be0030cd40d5..2c7b2330505e 100644 --- a/include/tvm/topi/nn/dense.h +++ b/include/tvm/topi/nn/dense.h @@ -46,7 +46,7 @@ using namespace tvm::te; * \return Tensor with shape [batch, out_dim] */ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, const DataType& out_dtype) { + const tvm::te::Tensor& bias, const PrimType& out_dtype) { TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 0c8ea395c701..f45543eda337 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -95,7 +95,7 @@ inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilat if (not_zero.size() > 0) { auto all_not_zero = 
all(not_zero); return tvm::if_then_else(all_not_zero, x(index_tuple), - MakeConst(x->dtype, dilation_value)); + MakeConst(PrimType(x->dtype), dilation_value)); } return x(index_tuple); }, diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 4962587a9396..7a778dea8ce5 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -45,9 +45,9 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "group_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "group_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // reshape data C -> G, C/G int ndim = data->shape.size(); channel_axis = GetRealAxis(static_cast(ndim), ffi::Array({channel_axis}))[0]; @@ -65,7 +65,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& } Tensor data_reshaped; if (is_float16) { - data_reshaped = cast(reshape(data, new_shape), DataType::Float(32)); + data_reshaped = cast(reshape(data, new_shape), PrimType::Float(32)); } else { data_reshaped = reshape(data, new_shape); } @@ -126,7 +126,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - PrimExpr reduce_extent = FloatImm(DataType::Float(32), 1); + PrimExpr reduce_extent = FloatImm(PrimType::Float(32), 1); for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } @@ -142,10 +142,10 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& gamma_indices = {indices[channel_axis], 
indices[channel_axis + 1]}; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - PrimExpr group_norm = - (data_reshaped(indices) - mean) * tvm::rsqrt(var + MakeConst(data->dtype, epsilon)); + PrimExpr group_norm = (data_reshaped(indices) - mean) * + tvm::rsqrt(var + MakeConst(PrimType(data->dtype), epsilon)); if (is_float16) { - group_norm = Cast(DataType::Float(16), group_norm); + group_norm = Cast(PrimType::Float(16), group_norm); } if (gamma.defined()) { group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices)); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 60361e8bc681..e246d97a59df 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -58,9 +58,9 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "instance_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "instance_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // sum x and x^2 auto ndim = data->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; @@ -69,9 +69,10 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto target_shape = MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true); auto func = MakeTupleSumReducer(); + PrimType f32_ty = PrimType::Float(32); - auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func, - &data](const ffi::Array& indices) { + auto compute = [ndim, is_float16, 
&real_axis, &reduce_axes, &func, &data, + f32_ty](const ffi::Array& indices) { ffi::Array eval_range; int arg_counter = 0; int red_counter = 0; @@ -86,15 +87,14 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso arg_counter++; } } - auto square = [is_float16](const PrimExpr& x) { + auto square = [is_float16, f32_ty](const PrimExpr& x) { if (is_float16) { - return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x); + return Cast(f32_ty, x) * Cast(f32_ty, x); } return x * x; }; if (is_float16) { - return func({Cast(DataType::Float(32), data(eval_range)), square(data(eval_range))}, - reduce_axes, nullptr); + return func({Cast(f32_ty, data(eval_range)), square(data(eval_range))}, reduce_axes, nullptr); } else { return func({data(eval_range), square(data(eval_range))}, reduce_axes, nullptr); } @@ -106,7 +106,7 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - auto reduce_extent = MakeConst(data->dtype, 1); + auto reduce_extent = MakeConst(PrimType(data->dtype), 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; } @@ -124,9 +124,9 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso channel = indices[channel_axis]; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + MakeConst(var->dtype, epsilon)); + auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + MakeConst(var.ty(), epsilon)); if (is_float16) { - instance_norm = Cast(DataType::Float(16), instance_norm); + instance_norm = Cast(PrimType::Float(16), instance_norm); } instance_norm = topi::multiply(instance_norm, gamma(channel)); if (beta.defined()) { diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index fb8155ef654a..8a995d7b91fe 100644 --- 
a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -57,9 +57,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& beta_type = beta.defined() ? beta->dtype : data_type; TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "layer_norm: data, gamma and beta must have the same type"; - TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == PrimType::Float(32) || data_type == PrimType::Float(16)) << "layer_norm: only support float32 and float16 for now"; - bool is_float16 = data_type == DataType::Float(16); + bool is_float16 = data_type == PrimType::Float(16); // Two-pass algorithm for improved numerical stability: // pass1: mean = E[x] // pass2: var = E[(x - mean)^2] @@ -69,6 +69,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false); + PrimType f32_ty = PrimType::Float(32); auto make_eval_range = [&real_axis, &reduce_axes, ndim](const ffi::Array& non_reduce_indices) { @@ -91,17 +92,17 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& Tensor temp_sum = te::compute( target_shape, - [is_float16, &data, &reduce_axes, &make_eval_range](const ffi::Array& indices) { + [is_float16, &data, &reduce_axes, &make_eval_range, f32_ty](const ffi::Array& indices) { auto eval_range = make_eval_range(indices); PrimExpr x = data(eval_range); if (is_float16) { - x = Cast(DataType::Float(32), x); + x = Cast(f32_ty, x); } return sum(x, reduce_axes); }, data->op->name + "_sum", kCommReduce); - DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype; + PrimType reduce_dtype = is_float16 ? 
PrimType::Float(32) : PrimType(data->dtype); PrimExpr reduce_extent = MakeConst(reduce_dtype, 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; @@ -115,12 +116,12 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& Tensor temp_var_sum = te::compute( target_shape, - [is_float16, &data, &reduce_axes, &make_eval_range, - &temp_mean](const ffi::Array& indices) { + [is_float16, &data, &reduce_axes, &make_eval_range, &temp_mean, + f32_ty](const ffi::Array& indices) { auto eval_range = make_eval_range(indices); PrimExpr x = data(eval_range); if (is_float16) { - x = Cast(DataType::Float(32), x); + x = Cast(f32_ty, x); } PrimExpr diff = x - temp_mean(indices); return sum(diff * diff, reduce_axes); @@ -138,9 +139,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& } auto mean = temp_mean(non_reduce_indices); auto var = temp_var_sum(non_reduce_indices) / reduce_extent; - auto layer_norm = (data(indices) - mean) * rsqrt(var + MakeConst(var->dtype, epsilon)); + auto layer_norm = (data(indices) - mean) * rsqrt(var + MakeConst(var.ty(), epsilon)); if (is_float16) { - layer_norm = Cast(DataType::Float(16), layer_norm); + layer_norm = Cast(PrimType::Float(16), layer_norm); } layer_norm = topi::multiply(layer_norm, gamma(reduce_indices)); if (beta.defined()) { diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index 7407448f88c5..4f411076387d 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -55,7 +55,8 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 TVM_FFI_ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; TVM_FFI_ICHECK_EQ(size % 2, 1) << "size should be odd number"; TVM_FFI_ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; - TVM_FFI_ICHECK(data->dtype.is_float()) << "datatype should be float"; + // LRN only requires a 
floating-point element kind; lane encoding is irrelevant here. + TVM_FFI_ICHECK_EQ(data->dtype.code(), DLDataTypeCode::kDLFloat) << "datatype should be float"; auto input_shape = data->shape; ffi::Array pad_before{0, 0, 0, 0}; ffi::Array pad_after{0, 0, 0, 0}; @@ -79,9 +80,9 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 }, "tensor", "sqr_sum"); } - PrimExpr alpha_imm = tvm::te::MakeConst(data->dtype, alpha); - PrimExpr beta_imm = tvm::te::MakeConst(data->dtype, beta); - PrimExpr bias_imm = tvm::te::MakeConst(data->dtype, bias); + PrimExpr alpha_imm = tvm::te::MakeConst(PrimType(data->dtype), alpha); + PrimExpr beta_imm = tvm::te::MakeConst(PrimType(data->dtype), beta); + PrimExpr bias_imm = tvm::te::MakeConst(PrimType(data->dtype), bias); auto sqrt_sum_up = tvm::te::compute( input_shape, [&](Var i, Var j, Var k, Var l) { diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index e8410d8add22..91b10e7d8df9 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -117,7 +117,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww"); auto argmax = MakeArgmaxReducer(); - auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + auto pad_x = + do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(PrimType(x->dtype)), "pad_temp") : x; auto mp_argmax = tvm::te::compute( out_shape, @@ -145,17 +146,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); PrimExpr out_idx_lower_h = tirx::Select( - pad_inds[height_axis] < kernel_height, IntImm(pad_inds[height_axis].dtype(), 0), + pad_inds[height_axis] < kernel_height, IntImm(pad_inds[height_axis].ty(), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = tirx::Select( - pad_inds[width_axis] < kernel_width, IntImm(pad_inds[width_axis].dtype(), 0), + pad_inds[width_axis] < kernel_width, IntImm(pad_inds[width_axis].ty(), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[width_axis] >= out_idx_lower_w), mp_inds(out_idx) == idx), - out_grad(out_idx), MakeConst(x->dtype, 0)), + out_grad(out_idx), MakeConst(PrimType(x->dtype), 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); @@ -176,10 +177,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); PrimExpr out_idx_lower_h = - tirx::Select(pad_h_idx < kernel_height, IntImm(pad_h_idx.dtype(), 0), + tirx::Select(pad_h_idx < kernel_height, IntImm(pad_h_idx.ty(), 0), (pad_h_idx - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = - tirx::Select(pad_w_idx < kernel_width, IntImm(pad_w_idx.dtype(), 0), + tirx::Select(pad_w_idx < kernel_width, IntImm(pad_w_idx.ty(), 0), (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements @@ -191,16 +192,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, PrimExpr h_end = min(h_start + kernel_height, height); PrimExpr w_end = min(w_start + kernel_width, width); - h_start = 
max(h_start, IntImm(h_start.dtype(), 0)); - w_start = max(w_start, IntImm(w_start.dtype(), 0)); - divide_factor = max((h_end - h_start) * (w_end - w_start), MakeConst(h_end.dtype(), 1)); + h_start = max(h_start, IntImm(h_start.ty(), 0)); + w_start = max(w_start, IntImm(w_start.ty(), 0)); + divide_factor = max((h_end - h_start) * (w_end - w_start), MakeConst(h_end.ty(), 1)); } return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[height_axis] < out_height), tirx::And(out_idx[width_axis] >= out_idx_lower_w, out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, MakeConst(out_grad->dtype, 0)), + out_grad(out_idx) / divide_factor, + MakeConst(PrimType(out_grad->dtype), 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -384,9 +386,9 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, false); - PrimExpr divide_factor = tvm::cast(x->dtype, 1); + PrimExpr divide_factor = tvm::cast(PrimType(x->dtype), 1); for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent); + divide_factor *= tvm::cast(PrimType::Int(32), reduce_axes[i]->dom->extent); } return div(pool_sum(indices), divide_factor); @@ -582,7 +584,8 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s ffi::Map attrs; if (pool_type == kMaxPool) { - auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + auto temp = + do_pad ? pad(x, pad_before, pad_after, tvm::min_value(PrimType(x->dtype)), "pad_temp") : x; attrs.Set("schedule_rule", tvm::ffi::String("meta_schedule.pool_max")); return tvm::te::compute( out_shape, @@ -657,7 +660,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s // number that represents the number of steps along the dilated kernel to reach a // non-padded value. 
Otherwise this should be 0. PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i]; - jumps_to_non_pad = max(jumps_to_non_pad, IntImm(jumps_to_non_pad.dtype(), 0)); + jumps_to_non_pad = max(jumps_to_non_pad, IntImm(jumps_to_non_pad.ty(), 0)); end[i] = min(end[i], data_shape[ii] - 1); num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1; diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 294d82054e3e..29f46918a754 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -54,8 +54,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra const auto& weight_type = weight.defined() ? weight->dtype : data_type; TVM_FFI_ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; - const auto& data_fp32 = cast(data, DataType::Float(32)); - const auto& weight_fp32 = cast(weight, DataType::Float(32)); + const auto& data_fp32 = cast(data, PrimType::Float(32)); + const auto& weight_fp32 = cast(weight, PrimType::Float(32)); auto square = multiply(data_fp32, data_fp32); auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true); @@ -63,7 +63,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra auto ndim = data_fp32->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); - auto reduce_extent = MakeConst(data_fp32->dtype, 1); + auto reduce_extent = MakeConst(PrimType(data_fp32->dtype), 1); for (int i : real_axis) { reduce_extent *= data_fp32->shape[i]; } @@ -74,8 +74,8 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra non_reduce_indices.push_back(indices[i]); } } - auto output = - tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + MakeConst(data_type, epsilon)); + auto output = tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + + 
MakeConst(PrimType(data_type), epsilon)); return output; }; auto rsqrt_shape = ffi::Array(); diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index e6b4c5af1dea..fbea4a57eabf 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -259,7 +259,7 @@ inline Tensor CommReduceIdx(const Tensor& data, const ffi::Optional(ffi::Array lhs, ffi::Array rhs)>; /*! \brief An initializer function for a reduction */ -using FIdentity = std::function(std::vector types)>; +using FIdentity = std::function(std::vector types)>; /*! * \brief Create a commutative reducer for a reduction @@ -275,10 +275,10 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, return [fcombine, fidentity, name](ffi::Array exprs, const ffi::Array& axis, PrimExpr* condition) { ffi::Array lhs, rhs; - std::vector dtypes; + std::vector dtypes; for (size_t i = 0; i < exprs.size(); ++i) { - auto dtype = exprs[i].dtype(); + PrimType dtype = exprs[i].ty(); dtypes.push_back(dtype); lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype)); rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype)); @@ -330,7 +330,8 @@ inline PrimExpr ProdOp(PrimExpr source, ffi::Array axis, ffi::Array>& axis, bool keepdims = false, bool atleast1d = false) { - if (data->dtype.is_bool()) { + // Reduction dispatch only depends on boolean element kind; lane encoding is irrelevant here. 
+ if (data->dtype.code() == DLDataTypeCode::kDLBool) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } else { return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); @@ -477,7 +478,7 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { result.push_back(tvm::tirx::Select(is_smaller, lhs[1], rhs[1])); // val return result; }; - auto fidentity = [&](std::vector types) { + auto fidentity = [&](std::vector types) { ffi::Array result; result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx result.push_back(tvm::max_value(types[1])); // val @@ -539,7 +540,7 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { result.push_back(tvm::tirx::Select(is_bigger, lhs[1], rhs[1])); // val return result; }; - auto fidentity = [&](std::vector types) { + auto fidentity = [&](std::vector types) { ffi::Array result; result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx result.push_back(tvm::min_value(types[1])); // val @@ -601,7 +602,7 @@ inline FCommReduce MakeTupleSumReducer() { } return result; }; - auto fidentity = [](std::vector types) { + auto fidentity = [](std::vector types) { ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { result.push_back(tvm::tirx::MakeConst(types[i], 0)); diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e216cf86ced4..f2ede7af8aa0 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -44,8 +44,8 @@ #include #include +#include "tvm/ffi/dtype.h" #include "tvm/ir/expr.h" -#include "tvm/runtime/data_type.h" #include "tvm/tirx/expr.h" #include "tvm/tirx/op.h" #include "tvm/tirx/var.h" @@ -338,7 +338,8 @@ inline Tensor reshape(const Tensor& x, ffi::Array newshape, // If either the input shape or the target shape contains a zero, return an empty tensor. 
if (is_empty_shape(target_shape) || is_empty_shape(x->shape)) { return compute( - target_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, + target_shape, + [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 0); }, name, tag); } else { return compute( @@ -679,7 +680,7 @@ inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stri if (index->IsInstance() && extent->IsInstance() && stride->IsInstance()) { return tvm::IntImm( - tvm::DataType::Int(64), + tvm::PrimType::Int(64), StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride))); } return DynamicCanonicalizeIndex(index, extent, stride); @@ -835,14 +836,14 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b bool assume_inbound = true, std::string name = "T_strided_slice_dynamic", std::string tag = topi::kInjective) { - DataType index_dtype = begin->shape[0]->dtype; + PrimType index_ty = begin->shape[0].ty(); const int64_t num_dynamic_axes = begin->shape[0].as()->value; TVM_FFI_ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); TVM_FFI_ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { - auto ind = MakeConst(index_dtype, i); + auto ind = MakeConst(index_ty, i); begin_expr.push_back(begin(ind)); end_expr.push_back(end(ind)); strides_expr.push_back(strides(ind)); @@ -874,10 +875,10 @@ inline ffi::Array StridedSliceOutputShape(const ffi::Array& axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? 
begin[0].value().ty() : PrimType::Int(64); auto begin_canonicalized = - StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, index_dtype, slice_mode); + StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, index_ty, slice_mode); return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_canonicalized, true); } @@ -924,10 +925,10 @@ inline Tensor strided_slice_with_axes( std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? begin[0].value().ty() : PrimType::Int(64); auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, normalized_axes, - index_dtype, slice_mode); + index_ty, slice_mode); auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, normalized_axes, slice_mode, begin_expr); @@ -938,7 +939,7 @@ inline Tensor strided_slice_with_axes( for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { int64_t ax = normalized_axes[i]; - auto stride = MakeConst(strides[i]->dtype, strides_vec[i]); + auto stride = MakeConst(strides[i]->ty(), strides_vec[i]); PrimExpr ind = indices[ax] * stride + begin_expr[i]; real_indices.Set(ax, ind); } @@ -972,11 +973,11 @@ inline Tensor strided_slice(const Tensor& x, const ffi::Array> end_full(end); ffi::Array strides_full(strides); - DataType index_dtype = - (begin.size() > 0 && begin[0].defined()) ? begin[0].value()->dtype : DataType::Int(64); - const IntImm one = IntImm(index_dtype, 1); - const IntImm zero = IntImm(index_dtype, 0); - const IntImm max_range = max_value(index_dtype).as_or_throw(); + PrimType index_ty = + (begin.size() > 0 && begin[0].defined()) ? 
begin[0].value().ty() : PrimType::Int(64); + const IntImm one = IntImm(index_ty, 1); + const IntImm zero = IntImm(index_ty, 0); + const IntImm max_range = max_value(index_ty).as_or_throw(); for (size_t i = strides.size(); i < src_tensor_dim; ++i) { strides_full.push_back(one); @@ -1073,7 +1074,8 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, [&](const ffi::Array& out_index) { auto idx = tvm::if_then_else( indices(out_index) < 0 || indices(out_index) >= a_size, - tvm::FloatImm(a->dtype, std::numeric_limits::quiet_NaN()), indices(out_index)); + tvm::FloatImm(tvm::PrimType(a->dtype), std::numeric_limits::quiet_NaN()), + indices(out_index)); return a(UnravelIndex(idx, a_shape)); }, name, tag); @@ -1116,9 +1118,9 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - PrimExpr ret = - tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tirx::MakeConst(data->dtype, mask_value), data(out_index)); + PrimExpr ret = tvm::if_then_else( + tvm::cast(PrimType(valid_length->dtype), tid) >= valid_length(len_index), + tvm::tirx::MakeConst(PrimType(data->dtype), mask_value), data(out_index)); return ret; }, name, tag); @@ -1293,7 +1295,7 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int PrimExpr in_bounds = idx >= 0 && idx < axis_dim; return tvm::if_then_else( in_bounds, a(real_indices), - tvm::tirx::MakeConst(a->dtype, std::numeric_limits::quiet_NaN())); + tvm::tirx::MakeConst(PrimType(a->dtype), std::numeric_limits::quiet_NaN())); }, name, tag); } else { // mode == "wrap" @@ -1443,8 +1445,8 @@ inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, - tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 
0); }, + name, tag); } else { return compute( new_shape, @@ -1478,8 +1480,8 @@ inline Tensor dyn_tile(const Tensor& x, ffi::Array new_shape, size_t r size_t ndim = x->shape.size(); if (is_empty_shape(new_shape)) { return compute( - new_shape, [&](const ffi::Array& indices) { return tvm::cast(x->dtype, 0); }, name, - tag); + new_shape, [&](const ffi::Array& indices) { return tvm::cast(PrimType(x->dtype), 0); }, + name, tag); } else { return compute( new_shape, @@ -1526,7 +1528,9 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); TVM_FFI_ICHECK_GE(indices_dim_i, 1); } - TVM_FFI_ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); + // Index tensors are validated by integer element kind; vector lane encoding is irrelevant here. + PrimType indices_ty = indices->dtype; + TVM_FFI_ICHECK(indices_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)); ffi::Array out_shape; for (size_t i = 0; i < ndim_i; ++i) { @@ -1593,10 +1597,13 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim } for (size_t i = 0; i < indices_dim0; ++i) { indices_position.Set(0, IntImm::Int32(i)); - if (indices->dtype.is_int() || indices->dtype.is_uint()) { + // Index tensors are validated by integer element kind; vector lane encoding is + // irrelevant for choosing whether an index cast is needed. 
+ PrimType indices_ty = indices->dtype; + if (indices_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { real_indices.push_back(indices(indices_position)); } else { - real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); + real_indices.push_back(tvm::cast(tvm::PrimType::Int(32), indices(indices_position))); } } if (real_indices.size() == ndim_d) { @@ -1740,10 +1747,15 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::ArrayCanProveGreaterEqual(step, 1)) { // fast path for integer arange when step is positive num_elem = tvm::floordiv((stop - start + step - 1), step); @@ -1752,8 +1764,8 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr num_elem = tvm::floordiv((start - stop - step - 1), -step); } else { // fallback path for non-integer or step of unknown sign - num_elem = tvm::cast(DefaultIndexType(), - tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); + num_elem = tvm::cast(PrimType(DefaultIndexType()), + tvm::ceil(tvm::cast(tvm::PrimType::Float(32), stop - start) / step)); } num_elem = analyzer->Simplify(num_elem); @@ -1845,7 +1857,8 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, for (size_t i = 0; i < src.ndim(); ++i) { in_range = in_range && (src_indices[i] < src->shape[i]); } - return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0))); + return if_then_else(in_range, src(src_indices), + tvm::cast(PrimType(src->dtype), PrimExpr(0))); }, name, tag, attrs); } @@ -1960,7 +1973,7 @@ inline Tensor meta_schedule_layout_transform( ffi::Array iter_domain; iter_domain.reserve(src->shape.size()); for (const PrimExpr& e : src->shape) { - iter_domain.push_back(Range::FromMinExtent(IntImm(e->dtype, 0), e)); + iter_domain.push_back(Range::FromMinExtent(IntImm(e.ty(), 0), e)); } ffi::Array post_transform_shape = index_map->MapShape(src->shape, analyzer); return compute( @@ -1980,7 +1993,7 @@ 
inline Tensor meta_schedule_layout_transform( * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", +inline Tensor shape(const Tensor& src, PrimType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); ffi::Array out_shape{ndim}; @@ -1997,6 +2010,11 @@ inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = name, tag); } +inline Tensor shape(const Tensor& src, DLDataType dtype, const std::string name = "T_shape", + const std::string tag = kInjective) { + return shape(src, PrimType(dtype), name, tag); +} + /*! * \brief Get the size of input tensor. * \param src the input tensor. @@ -2005,7 +2023,7 @@ inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = * \param tag output tensor tag. * \return Tensor of input shape. */ -inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, +inline te::Tensor tensor_size(const te::Tensor& src, PrimType dtype, const std::string& name = "tensor_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); @@ -2022,6 +2040,12 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, name, tag); } +inline te::Tensor tensor_size(const te::Tensor& src, DLDataType dtype, + const std::string& name = "tensor_size", + const std::string& tag = kInjective) { + return tensor_size(src, PrimType(dtype), name, tag); +} + /*! * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, other locations take value off_value. @@ -2037,7 +2061,7 @@ inline te::Tensor tensor_size(const te::Tensor& src, const DataType& dtype, * \return one-hot tensor. 
*/ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, - int depth, int axis, const DataType& dtype, + int depth, int axis, PrimType dtype, ffi::Array oshape = ffi::Array(), const std::string name = "T_one_hot", const std::string tag = kInjective) { int true_axis = (axis == -1) ? indices->shape.size() : axis; @@ -2073,6 +2097,14 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim name, tag); } +inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, + int depth, int axis, DLDataType dtype, + ffi::Array oshape = ffi::Array(), + const std::string name = "T_one_hot", const std::string tag = kInjective) { + return one_hot(indices, on_value, off_value, depth, axis, PrimType(dtype), std::move(oshape), + name, tag); +} + /*! * \brief Get a dense tensor. * \param sparse_indices sparse_indices[i] contains sparse_values[i] will be placed. @@ -2088,7 +2120,9 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const PrimExpr& default_value, const std::string name = "T_sparse_to_dense", const std::string tag = kInjective) { - TVM_FFI_ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; + // Sparse indices are validated by signed integer element kind; lane encoding is irrelevant here. + TVM_FFI_ICHECK_EQ(sparse_indices->dtype.code(), DLDataTypeCode::kDLInt) + << "sparse_indices only accepts integer values"; TVM_FFI_ICHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices tensor should be 0D, 1D, or 2D only"; TVM_FFI_ICHECK_LE(sparse_values->shape.size(), 2) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 4fbebeddd0f5..e6a33ac9b4a6 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -43,7 +43,10 @@ class PrimExpr(BaseExpr): optimizations and integer analysis. 
""" - dtype: str + @property + def dtype(self): + """Return the runtime dtype represented by this expression's PrimType.""" + return self.ty.dtype @tvm_ffi.register_object("ir.RelaxExpr") diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 567ebafa2d5c..96548439d70e 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -53,6 +53,35 @@ class PrimType(Type): def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) + def __eq__(self, other): + if isinstance(other, str): + return self.dtype == other + return super().__eq__(other) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + dtype = self.dtype + return hash((dtype.type_code, dtype.bits, dtype.lanes)) + + def __str__(self): + return str(self.dtype) + + def matches_code(self, *codes) -> bool: + """Return whether this type has any of the given DLPack dtype codes.""" + type_code = self.dtype.type_code + return any(type_code == int(code) for code in codes) + + def matches_element_type(self, code, bits: int) -> bool: + """Return whether this type has the given scalar element code and bits.""" + dtype = self.dtype + return dtype.type_code == int(code) and dtype.bits == bits + + def is_scalar(self) -> bool: + """Return whether this type has exactly one fixed lane.""" + return self.dtype.lanes == 1 + @tvm_ffi.register_object("ir.PointerType") class PointerType(Type): diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 9c8efce690f1..6c7f3dc72c9f 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -145,7 +145,7 @@ def shape_dtype_inference(a, b): // those headers are guaranteed to be available #include - #include + #include #include namespace { diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f987f48d4251..b9ab88da0b43 100644 --- 
a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -29,6 +29,7 @@ import tvm_ffi from tvm import relax, tirx +from tvm.runtime import DataTypeCode class BaseFXGraphImporter(metaclass=abc.ABCMeta): @@ -566,7 +567,7 @@ def _pow(self, node: fx.Node) -> relax.Var: if ( isinstance(lhs, relax.Expr) and isinstance(lhs.ty, relax.TensorType) - and "int" in lhs.ty.dtype + and lhs.ty.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance(rhs, int) and not isinstance(rhs, bool) and rhs >= 0 @@ -1607,7 +1608,7 @@ def transpose_and_reshape_back(tensor): if attn_mask is not None: attn_mask = self.env[attn_mask] msg = "Only a float mask is supported for the attn_mask input." - assert "float" in attn_mask.ty.dtype, msg + assert attn_mask.ty.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT), msg attention_output = self.block_builder.emit( relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 9d28ed92f9c5..1bbeeee8f272 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -17,6 +17,7 @@ """Creation operators.""" from tvm import DataType, DataTypeCode +from tvm.ir import PrimType from tvm.ir.expr import PrimExpr from ..expr import Expr, PrimValue, ShapeExpr @@ -267,7 +268,12 @@ def is_int(expr): return True if isinstance(expr, PrimValue): expr = expr.value - return isinstance(expr, PrimExpr) and DataType(expr.dtype).type_code == DataTypeCode.INT # type: ignore + if isinstance(expr, PrimExpr): + dtype = expr.dtype # type: ignore + if isinstance(dtype, PrimType): + dtype = dtype.dtype + return DataType(dtype).type_code == DataTypeCode.INT + return False if dtype is None: args = (start, end, step) diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 4b787c265bc3..43a2bd400351 100644 --- 
a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -19,6 +19,7 @@ from collections.abc import Callable from tvm.ir.expr import PrimExpr +from tvm.runtime import DataTypeCode from tvm.tirx import FloatImm, IndexMap, IntImm from ..expr import Expr, PrimValue, ShapeExpr @@ -151,10 +152,12 @@ def layout_transform( if pad_value is None: pass elif not isinstance(pad_value, PrimValue): - if "int" in x_dtype and isinstance(pad_value, int): - pad_value = IntImm(x_dtype, pad_value) - elif "float" in x_dtype and (isinstance(pad_value, int | float)): - pad_value = FloatImm(x_dtype, float(pad_value)) + if x_dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance(pad_value, int): + pad_value = IntImm(x_dtype.dtype, pad_value) + elif x_dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) and ( + isinstance(pad_value, int | float) + ): + pad_value = FloatImm(x_dtype.dtype, float(pad_value)) pad_value = PrimValue(pad_value) if axis_separators is None: diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py index 1b7d1179a521..f464c248e363 100644 --- a/python/tvm/relax/transform/legalize_ops/common.py +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -20,6 +20,7 @@ import tvm from tvm import te +from tvm.runtime import DataTypeCode from tvm.tirx import FloatImm, IntImm from ...block_builder import BlockBuilder @@ -38,9 +39,6 @@ LegalizeFunc = Callable[[BlockBuilder, Call], Expr] -##################### Utilities ##################### - - def _try_convert_to_scalar_const( expr: Expr, python_native: bool = False ) -> Expr | FloatImm | IntImm | bool | float | int: @@ -69,13 +67,14 @@ def _try_convert_to_scalar_const( # get the value of the scalar constant value = expr.data.numpy()[()].item() dtype = expr.ty.dtype + dtype_str = str(dtype.dtype) if python_native: return value # preserve the data type of the constant - if dtype.startswith("float"): - return 
tvm.tirx.FloatImm(dtype, value) - elif dtype.startswith("int") or dtype.startswith("uint") or dtype.startswith("bool"): - return tvm.tirx.IntImm(dtype, value) + if dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): + return tvm.tirx.FloatImm(dtype_str, value) + elif dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + return tvm.tirx.IntImm(dtype_str, value) return expr diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index f0cc8977d4ef..a59b1f9fe52e 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -19,7 +19,7 @@ """Default legalization function for manipulate operators.""" import tvm -from tvm import relax, s_tir, te, tirx, topi +from tvm import DataTypeCode, relax, s_tir, te, tirx, topi from tvm.relax.op.base import call_tir from tvm.relax.type import TensorType from tvm.relax.utils import gen_call_tir_inputs @@ -337,7 +337,7 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): if pad_value is not None: pad_value = pad_value.value else: - if "int" in call.args[0].ty.dtype: + if call.args[0].ty.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT): pad_value = 0 else: pad_value = 0.0 diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index aa86f6fca2c3..7a825e300e40 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -19,6 +19,7 @@ import tvm from tvm import te, tirx +from tvm.runtime import DataTypeCode from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -140,7 +141,11 @@ def dequantize_compute(*indices): zp_value = zp[(0,) * len(zp.shape)] else: zp_value = zp[indices[axis]] - dtype = "float32" if "float" in data.dtype else "int32" + dtype = ( + "float32" + if data.dtype.matches_code(DataTypeCode.FLOAT, 
DataTypeCode.BFLOAT) + else "int32" + ) sub = te.subtract(data[indices].astype(dtype), zp_value) out = te.multiply(sub, scale_value.astype("float32")) if out_dtype == "float32": diff --git a/python/tvm/relax/type.py b/python/tvm/relax/type.py index ad8f469826ef..305f01750306 100644 --- a/python/tvm/relax/type.py +++ b/python/tvm/relax/type.py @@ -21,7 +21,7 @@ import tvm_ffi from tvm_ffi import Array -from tvm.ir import EnvFunc, PrimExpr, Span, TupleType, VDevice +from tvm.ir import EnvFunc, PrimExpr, PrimType, Span, TupleType, VDevice from . import _ffi_api from .expr import Expr, ShapeExpr, Type @@ -92,7 +92,7 @@ class TensorType(Type): """ shape: Expr | None - dtype: str + dtype: PrimType vdevice: VDevice | None ndim: int span: Span @@ -100,13 +100,15 @@ class TensorType(Type): def __init__( self, shape: Expr | None | list[PrimExpr] = None, - dtype: str = "float32", + dtype: str | PrimType | None = "float32", vdevice: VDevice | None | str = None, ndim: int = -1, span: Span = None, ) -> None: if isinstance(shape, list | tuple | Array): shape = ShapeExpr(shape) + if dtype is not None and not isinstance(dtype, PrimType): + dtype = PrimType(dtype) self.__init_handle_by_constructor__( _ffi_api.TensorType, shape, diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 51c8805f9445..505613d0372e 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -66,5 +66,9 @@ def const(value, dtype=None, span=None): if dtype is None: dtype = _scalar_type_inference(value) if dtype == "uint64" and value >= (1 << 63): - return _ffi_node_api.LargeUIntImm(dtype, value & ((1 << 32) - 1), value >> 32, span) + from tvm.ir import PrimType # pylint: disable=import-outside-toplevel + + return _ffi_node_api.LargeUIntImm( + PrimType(dtype), value & ((1 << 32) - 1), value >> 32, span + ) return _ffi_node_api._const(value, dtype, span) diff --git a/python/tvm/s_tir/schedule/schedule.py 
b/python/tvm/s_tir/schedule/schedule.py index 7f191df98d84..25b81239189d 100644 --- a/python/tvm/s_tir/schedule/schedule.py +++ b/python/tvm/s_tir/schedule/schedule.py @@ -24,7 +24,7 @@ from tvm.error import register_error from tvm.ir import GlobalVar, IRModule, PrimExpr -from tvm.runtime import Object +from tvm.runtime import DataTypeCode, Object from tvm.tirx import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock from tvm.tirx.function import IndexMap @@ -3465,10 +3465,14 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> # buffer's type. If the default `tvm.runtime.convert` # behavior is applied, these would be converted to # int32/float32, which may not match the buffer's type. - if "int" in buffer_obj.dtype and isinstance(pad_value, int): - pad_value = IntImm(buffer_obj.dtype, pad_value) - elif "float" in buffer_obj.dtype and isinstance(pad_value, float): - pad_value = FloatImm(buffer_obj.dtype, pad_value) + if buffer_obj.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) and isinstance( + pad_value, int + ): + pad_value = IntImm(buffer_obj.dtype.dtype, pad_value) + elif buffer_obj.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT) and ( + isinstance(pad_value, float) + ): + pad_value = FloatImm(buffer_obj.dtype.dtype, pad_value) pad_value = IndexMap.from_func( lambda *indices: pad_value, ndim=len(index_map.final_indices), diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 4d38292b9b56..dec30f29a114 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -396,7 +396,7 @@ def _eval_if_exp(self, fields: dict[str, Any]) -> Any: orelse = self._eval_expr(fields["orelse"]) if isinstance(test, bool): return body if test else orelse - elif isinstance(test, tvm.tirx.PrimExpr) and test.dtype == "bool": + elif isinstance(test, tvm.tirx.PrimExpr) and test.dtype.type_code == tvm.DataTypeCode.BOOL: return 
tvm.tirx.op.if_then_else(test, body, orelse) else: raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 531915c6798a..b7238cf07eda 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -19,6 +19,7 @@ # pylint: disable=invalid-name import tvm_ffi +from tvm.ir import PrimType from tvm.runtime import Object, ObjectConvertible from tvm.tirx import DataProducer from tvm.tirx import expr as _expr @@ -49,6 +50,10 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype + def expr_ty(self): + """Compile-time element type of the tensor.""" + return self.tensor.expr_ty() + @tvm_ffi.register_object("te.Tensor") class Tensor(DataProducer, _expr.ExprOp): @@ -86,6 +91,15 @@ def ndim(self): """Dimension of the tensor.""" return len(self.shape) + @property + def dtype(self): + """Data content of the tensor.""" + return PrimType(_ffi_api.TensorDType(self)) + + def expr_ty(self): + """Compile-time element type of the tensor.""" + return self.dtype + @property def name(self): op = self.op diff --git a/python/tvm/tirx/buffer.py b/python/tvm/tirx/buffer.py index 4caf154547fa..d021bb317220 100644 --- a/python/tvm/tirx/buffer.py +++ b/python/tvm/tirx/buffer.py @@ -544,7 +544,7 @@ def decl_buffer( elem_offset = Var(f"{name}_elem_offset", shape_dtype) if data is None: # Bool is represented as uint1 in the IR, but stored as int8 - storage_type = PrimType(dtype) + storage_type = dtype if isinstance(dtype, PrimType) else PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type data = Var(name, PointerType(storage_type, scope), span) return _ffi_api.Buffer( # type: ignore diff --git a/python/tvm/tirx/expr.py b/python/tvm/tirx/expr.py index a97171e436ae..2e01f0b6d556 100644 --- a/python/tvm/tirx/expr.py +++ b/python/tvm/tirx/expr.py @@ -34,7 +34,7 @@ from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import 
Span -from tvm.runtime import DataType, DataTypeCode, Object, ObjectConvertible, Scriptable, const +from tvm.runtime import DataTypeCode, Object, ObjectConvertible, Scriptable, const from . import _ffi_api from . import generic as _generic @@ -56,13 +56,17 @@ def div_ambiguity_error() -> RuntimeError: def _dtype_is_int(value): if isinstance(value, int): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT # type: ignore + if isinstance(value, ExprOp): + return value.expr_ty().matches_code(DataTypeCode.INT) + return False def _dtype_is_float(value): if isinstance(value, float): return True - return isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT # type: ignore + if isinstance(value, ExprOp): + return value.expr_ty().matches_code(DataTypeCode.FLOAT) + return False class ExprOp: @@ -70,6 +74,13 @@ class ExprOp: # TODO(tkonolige): use inspect to add source information to these objects + def expr_ty(self) -> ir.PrimType: + """Return the compile-time primitive type for expression operators.""" + ty = getattr(self, "ty", None) + if isinstance(ty, ir.PrimType): + return ty + raise TypeError(f"Cannot determine PrimType for {type(self).__name__}") + def __add__(self, other: PrimExpr) -> PrimExpr: return _generic.add(self, other) @@ -259,6 +270,10 @@ def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the equality result.""" + return ir.PrimType("bool") + def __repr__(self) -> str: return f"EqualOp({self.a!r}, {self.b!r})" @@ -299,6 +314,10 @@ def asobject(self) -> PrimExpr: """Convert object.""" return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the inequality result.""" + return ir.PrimType("bool") + def __repr__(self) -> str: return f"NotEqualOp({self.a!r}, {self.b!r})" @@ -473,6 +492,10 
@@ def __init__( span, # type: ignore ) + def expr_ty(self) -> ir.PrimType: + """Compile-time type of the iteration variable.""" + return self.var.ty + @tvm_ffi.register_object("tirx.CommReducer") class CommReducer(Object, Scriptable): @@ -1332,6 +1355,8 @@ def __init__( op = Op.get(op) if isinstance(attrs, dict): attrs = ir.make_node("ir.DictAttrs", **attrs) + if not isinstance(dtype, ir.PrimType): + dtype = ir.PrimType(dtype) if attrs: self.__init_handle_by_constructor__( # type: ignore _ffi_api.CallWithAttrs, dtype, op, args, attrs, span diff --git a/python/tvm/tirx/script/parser/operation.py b/python/tvm/tirx/script/parser/operation.py index dac8f06ebf80..c6cb50f291af 100644 --- a/python/tvm/tirx/script/parser/operation.py +++ b/python/tvm/tirx/script/parser/operation.py @@ -17,7 +17,8 @@ """The tirx expression operation registration""" from tvm import tirx -from tvm.runtime import DataType, DataTypeCode +from tvm.ir import PrimType +from tvm.runtime import DataTypeCode from tvm.script.parser._core import OpMethod, doc, register_op from tvm.tirx import IntImm from tvm.tirx.expr import FloatImm @@ -26,12 +27,18 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name ty._dispatch_type = ty # pylint: disable=protected-access + def _expr_ty(expr): + ty = expr.expr_ty() + if not isinstance(ty, PrimType): + raise TypeError(f"Expected a PrimType expression, but got {ty}") + return ty + def _and(a, b): if isinstance(a, bool): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + if not _expr_ty(a).is_scalar() or not _expr_ty(b).is_scalar(): return a & b else: return tirx.And(a, b) @@ -41,58 +48,56 @@ def _or(a, b): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + if not _expr_ty(a).is_scalar() or not _expr_ty(b).is_scalar(): return a | b else: return tirx.Or(a, b) - def _get_type_str(dtype: 
str): - if DataType(dtype).lanes == 1: - return dtype - index = dtype.find("x") - return dtype[0:index] + def _get_type_str(ty: PrimType): + dtype_str = str(ty.dtype) + if ty.is_scalar(): + return dtype_str + index = dtype_str.find("x") + return dtype_str[0:index] def _auto_broadcast(a, b, op): if isinstance(a, int): if hasattr(b, "dtype"): - if ( - DataType(b.dtype).type_code == DataTypeCode.INT - or DataType(b.dtype).type_code == DataTypeCode.UINT - or DataType(b.dtype).type_code == DataTypeCode.BOOL - ): - a = IntImm(_get_type_str(b.dtype), a) - elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: - a = FloatImm(_get_type_str(b.dtype), a) + b_ty = _expr_ty(b) + if b_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + a = IntImm(_get_type_str(b_ty), a) + elif b_ty.matches_code(DataTypeCode.FLOAT): + a = FloatImm(_get_type_str(b_ty), a) elif isinstance(b, float): a = FloatImm("float32", a) else: a = IntImm("int32", a) elif isinstance(a, float): - if DataType(b.dtype).type_code == DataTypeCode.FLOAT: - a = FloatImm(_get_type_str(b.dtype), a) + b_ty = _expr_ty(b) + if b_ty.matches_code(DataTypeCode.FLOAT): + a = FloatImm(_get_type_str(b_ty), a) else: a = FloatImm("float32", a) assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr." 
if isinstance(b, int): - if ( - DataType(a.dtype).type_code == DataTypeCode.INT - or DataType(a.dtype).type_code == DataTypeCode.UINT - or DataType(a.dtype).type_code == DataTypeCode.BOOL - ): - b = IntImm(_get_type_str(a.dtype), b) - elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: - b = FloatImm(_get_type_str(a.dtype), b) + a_ty = _expr_ty(a) + if a_ty.matches_code(DataTypeCode.INT, DataTypeCode.UINT, DataTypeCode.BOOL): + b = IntImm(_get_type_str(a_ty), b) + elif a_ty.matches_code(DataTypeCode.FLOAT): + b = FloatImm(_get_type_str(a_ty), b) elif isinstance(b, float): - b = FloatImm(_get_type_str(a.dtype), b) + b = FloatImm(_get_type_str(_expr_ty(a)), b) - if DataType(a.dtype).lanes == DataType(b.dtype).lanes: + a_ty = _expr_ty(a) + b_ty = _expr_ty(b) + if a_ty.dtype.lanes == b_ty.dtype.lanes: return op(a, b) - elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_a = tirx.Broadcast(a, DataType(b.dtype).lanes) + elif a_ty.is_scalar() and a_ty.dtype.lanes != b_ty.dtype.lanes: + broadcast_a = tirx.Broadcast(a, b_ty.dtype.lanes) return op(broadcast_a, b) - elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_b = tirx.Broadcast(b, DataType(a.dtype).lanes) + elif b_ty.is_scalar() and a_ty.dtype.lanes != b_ty.dtype.lanes: + broadcast_b = tirx.Broadcast(b, a_ty.dtype.lanes) return op(a, broadcast_b) else: raise TypeError("do not know how to deal with it.") diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index d3e8991c85c7..6088c4baa800 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -18,7 +18,7 @@ # pylint: disable=redefined-builtin,unused-argument import tvm -from tvm import DataType, DataTypeCode, te +from tvm import DataTypeCode, te from tvm.tirx import PrimExpr from . 
import cpp, tag @@ -26,11 +26,15 @@ def _require_float_tensor(op_name, x): - if DataType(x.dtype).type_code not in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT): + if not x.dtype.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): raise TypeError(f"topi.{op_name} only supports floating-point inputs, but got {x.dtype}") return x +def _is_integer_tensor(x): + return x.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def identity(x): """Take identity of input x. @@ -478,7 +482,7 @@ def log(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log(x(*i)), tag=tag.ELEMWISE) @@ -496,7 +500,7 @@ def log2(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log2(x(*i)), tag=tag.ELEMWISE) @@ -514,7 +518,7 @@ def log10(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.log10(x(*i)), tag=tag.ELEMWISE) @@ -533,7 +537,7 @@ def sqrt(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.sqrt(x(*i))) @@ -552,7 +556,7 @@ def rsqrt(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int"): + if x.dtype.matches_code(DataTypeCode.INT): x = te.compute(x.shape, lambda *i: x(*i).astype("float32")) return te.compute(x.shape, lambda *i: te.rsqrt(x(*i))) @@ -798,7 +802,7 @@ def fast_exp(x): y : tvm.te.Tensor The result. 
""" - if x.dtype.startswith("int") or x.dtype.startswith("uint"): + if _is_integer_tensor(x): x = cast(x, "float32") return cpp.fast_exp(x, x.dtype, tag.ELEMWISE) @@ -816,7 +820,7 @@ def fast_tanh(x): y : tvm.te.Tensor The result. """ - if x.dtype.startswith("int") or x.dtype.startswith("uint"): + if _is_integer_tensor(x): x = cast(x, "float32") return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) @@ -855,24 +859,26 @@ def ceil_log2(x): if not isinstance(x, tvm.tirx.PrimExpr): x = tvm.tirx.const(x) - if "float" in x.dtype: + if x.ty.matches_code(DataTypeCode.FLOAT, DataTypeCode.BFLOAT): return tvm.tirx.ceil(tvm.tirx.log2(x)) target = tvm.target.Target.current() - if "vulkan" in target.kind.name: - clz = tvm.tirx.clz(x) - bits = int(x.dtype[-2:]) - res = tvm.tirx.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) - if res.dtype != x.dtype: - return cast(res, x.dtype) - return res - - if "adreno" in str(target.attrs.get("device", "")) or target.kind.name in [ - "metal", - "rocm", - "webgpu", - ]: - return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float32"))), x.dtype) + if target is not None: + target_name = target.kind.name + if "vulkan" in target_name: + clz = tvm.tirx.clz(x) + bits = x.ty.dtype.bits + res = tvm.tirx.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) + if res.dtype != x.dtype: + return cast(res, x.dtype) + return res + + if "adreno" in str(target.attrs.get("device", "")) or target_name in [ + "metal", + "rocm", + "webgpu", + ]: + return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float32"))), x.dtype) return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float64"))), x.dtype) diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index bf5b86599854..de35577c4d85 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,7 +18,7 @@ # ruff: noqa: E741 """ScatterND operator""" -from tvm import te, tirx # hide redefinition of min and max +from tvm import DataTypeCode, te, tirx # hide redefinition of min and max 
from tvm.arith.analyzer import Analyzer from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tirx as T @@ -49,7 +49,7 @@ def _verify_scatter_nd_inputs(data, indices, updates): f"of out_shape[{i}] ({data.shape[i]})." ) - assert "int" in indices.dtype, ( + assert indices.dtype.matches_code(DataTypeCode.INT, DataTypeCode.UINT), ( f"Indices must be a tensor of integers, but its elements are {indices.dtype}." ) diff --git a/python/tvm/topi/sort.py b/python/tvm/topi/sort.py index 81821e462dcf..846573db5036 100644 --- a/python/tvm/topi/sort.py +++ b/python/tvm/topi/sort.py @@ -110,7 +110,7 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): f = tvm.compile(s, [data, out], "llvm") dev = tvm.cpu() tvm_data = tvm.runtime.tensor(np_data, dev) - tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype), dev) + tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype.dtype), dev) f(tvm_data, tvm_out) """ data_buf = tvm.tirx.decl_buffer( diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index fc59f891e1bf..94eb8788846b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -73,7 +73,8 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // decompose value as symbol * scale + offset int64_t offset = 0; - PrimExpr symbol_scale = tirx::MakeConst(value.dtype(), 0); + PrimType value_ty = value.ty(); + PrimExpr symbol_scale = tirx::MakeConst(value_ty, 0); auto fcollect_sum = [&](PrimExpr val, int sign) { if (const auto* intimm = val.as()) { @@ -90,7 +91,7 @@ void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // split out the symbol and non-symbolic part int64_t cscale = 1; - PrimExpr symbol = tirx::MakeConst(value.dtype(), 1); + PrimExpr symbol = tirx::MakeConst(value_ty, 1); auto fcollect_prod = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -110,7 +111,7 
@@ void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { Var var = ffi::GetRef(var_ptr); // skip non-index type, keep it to be compatible // with any_dim that do not represent any value - if (!IsIndexType(var.dtype())) return; + if (!IsIndexTypedExpr(var)) return; bool allow_override = true; // mark the constant bound is sufficient // we cannot mark interval set as that will cause relaxation of the var @@ -169,7 +170,7 @@ bool AnalyzerObj::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value == crhs->value; - if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) { + if (lhs->ty().IsHandle() || rhs->ty().IsHandle()) { return lhs.same_as(rhs); } return CanProve(lhs - rhs == 0); @@ -189,7 +190,7 @@ bool AnalyzerObj::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, } }; UnpackReduction(shape, fcollect); - PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale)); + PrimExpr const_shape_bound = IntImm(shape.ty(), std::abs(cscale)); if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound)) return true; return false; } diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 475a687cd462..bceeb4eafa2e 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -96,7 +96,8 @@ class BoundDeducer : public ExprFunctor { void VisitExprDefault_(const ffi::Object* op) final { success_ = false; } SignType GetSignType(const PrimExpr& e) { - if (e.dtype().is_uint()) { + PrimType e_ty = e.ty(); + if (e_ty.code() == DLDataTypeCode::kDLUInt) { return kPositive; } return expr_map_[e].GetSignType(); diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 12344cffd1d8..17a6ba022e2b 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -83,14 +83,14 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { * \param analyzer The analyzer 
* \return whether value fits in dtype */ -bool CastIsSafe(DataType dtype, PrimExpr value, AnalyzerObj* analyzer) { - if (!IsIndexType(dtype)) { +bool CastIsSafe(PrimType dtype, PrimExpr value, AnalyzerObj* analyzer) { + if (!IsIndexType(dtype->dtype)) { return false; } ConstIntBound bound = analyzer->const_int_bound(value); int64_t ubound = max_value(dtype).as_or_throw()->value; int64_t lbound = min_value(dtype).as_or_throw()->value; - if (value.dtype().bits() <= dtype.bits() || // upcast is safe + if (value.ty().bits() <= dtype.bits() || // upcast is safe (bound->max_value <= ubound && bound->min_value >= lbound)) { return true; } @@ -128,7 +128,7 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; - DataType dtype = this->dtype; + PrimType dtype = this->ty(); if (this->scale == 0) { return IntImm(dtype, 0); } @@ -140,7 +140,7 @@ class SplitExprNode : public CanonicalExprNode { } sscale *= this->scale; if (sscale != 1) { - TVM_FFI_ICHECK(!dtype.is_uint() || sscale > 0); + TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0); res = res * MakeConst(dtype, sscale); } return res; @@ -156,12 +156,12 @@ class SplitExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { + bool CanPushCastToChildren(PrimType dtype, AnalyzerObj* analyzer) const { // cast(dtype, index % upper_factor / lower_factor * scale) == // cast(dtype, index) % upper_factor / lower_factor * scale // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->dtype.bits()) { + if (dtype.bits() >= this->ty().bits()) { return true; // upcast is safe } PrimExpr res = this->index; @@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode { return false; } if 
(this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, MakeConst(this->dtype, this->upper_factor), div_mode); + res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->lower_factor != 1) { - res = DivImpl(res, MakeConst(this->dtype, this->lower_factor), div_mode); + res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->scale != 1) { - TVM_FFI_ICHECK(!this->dtype.is_uint() || this->scale > 0); - res = res * MakeConst(this->dtype, this->scale); + TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0); + res = res * MakeConst(this->ty(), this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -197,9 +197,9 @@ class SplitExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void PushCastToChildren(DataType dtype) { + void PushCastToChildren(PrimType dtype) { this->index = cast(dtype, this->index); - this->dtype = dtype; + this->BaseExprNode::ty = dtype; } inline bool IndexEqual(const SplitExpr& other) const; @@ -252,9 +252,9 @@ class SumExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { // quick path 1. if (this->args.size() == 0) { - return MakeConst(this->dtype, this->base); + return MakeConst(this->ty(), this->base); } - return Normalize_(this->dtype, SimplifySplitExprs(args), base); + return Normalize_(this->ty(), SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -334,14 +334,14 @@ class SumExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { + bool CanPushCastToChildren(PrimType dtype, AnalyzerObj* analyzer) const { bool is_min_value = dtype.bits() == 64 ? 
base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // cast(dtype, arg_1 + arg_2 + ... arg_n) == // cast(dtype, arg_1) + ... + cast(dtype, arg_n) // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of // its intermediate results fit in the range of dtype - if (dtype.bits() >= this->dtype.bits()) { + if (dtype.bits() >= this->ty().bits()) { return true; // upcast is safe } PrimExpr res = IntImm(dtype, 0); @@ -386,11 +386,11 @@ class SumExprNode : public CanonicalExprNode { * \brief self = cast(dtype, self) * \param dtype The target datatype */ - void PushCastToChildren(DataType dtype) { + void PushCastToChildren(PrimType dtype) { for (auto& arg : args) { arg.CopyOnWrite()->PushCastToChildren(dtype); } - this->dtype = dtype; + this->BaseExprNode::ty = dtype; } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.SumExpr", SumExprNode, CanonicalExprNode); @@ -496,7 +496,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { + static PrimExpr Normalize_(PrimType dtype, const std::vector& args, int64_t base) { bool is_min_value = dtype.bits() == 64 ? 
base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // Positive scales first @@ -648,7 +648,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { expr = op->Normalize(); } ffi::ObjectPtr n = ffi::make_object(); - n->dtype = expr.dtype(); + n->BaseExprNode::ty = expr.ty(); n->index = std::move(expr); n->div_mode = kTruncDiv; return SplitExpr(n); @@ -685,7 +685,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { return op.value(); } ffi::ObjectPtr n = ffi::make_object(); - n->dtype = expr.dtype(); + n->BaseExprNode::ty = expr.ty(); if (const auto* op = expr.as()) { n->base = op->value; return SumExpr(n); @@ -699,7 +699,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { }; PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -723,7 +723,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -747,7 +747,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -794,8 +794,8 @@ void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, SumExpr* out_non_divisible) { auto divisible = ffi::make_object(); auto non_divisible = ffi::make_object(); - divisible->dtype = psum->dtype; - non_divisible->dtype = psum->dtype; + divisible->BaseExprNode::ty = psum->ty(); + non_divisible->BaseExprNode::ty = psum->ty(); if (psum->base % coeff == 0) { divisible->base = psum->base; @@ -834,11 +834,11 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr 
lhs, int64_t cval, return lhs; } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { // (x % c1) / c2 => 0 when c2 >= c1 - return ToSplitExpr(IntImm(lhs.dtype(), 0)); + return ToSplitExpr(IntImm(lhs.ty(), 0)); } else { // move the upper_factor modular into index. lhs.CopyOnWrite()->index = - ModImpl(lhs->index, MakeConst(lhs.dtype(), lhs->upper_factor), div_mode); + ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode); lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; lhs.CopyOnWrite()->scale = 1; lhs.CopyOnWrite()->lower_factor *= scaled_cval; @@ -862,8 +862,9 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (prhs->as()) return false; // collect lhs products and try to eliminate by matching them to prod in rhs ffi::Array> lhs_prods; - PrimExpr new_rhs = MakeConst(prhs->dtype(), 1); - PrimExpr new_common_scale = MakeConst(prhs->dtype(), 1); + PrimType rhs_ty = prhs->ty(); + PrimExpr new_rhs = MakeConst(rhs_ty, 1); + PrimExpr new_common_scale = MakeConst(rhs_ty, 1); int64_t lhs_cscale = 1, rhs_cscale = 1; int num_elimination = 0; @@ -905,18 +906,19 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (num_elimination == 0 && cscale_gcd == 1) return false; // construct prod via canonical form - PrimExpr new_lhs = MakeConst(plhs->dtype(), 1); + PrimType lhs_ty = plhs->ty(); + PrimExpr new_lhs = MakeConst(lhs_ty, 1); for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } - *plhs = new_lhs * MakeConst(plhs->dtype(), lhs_cscale); - *prhs = new_rhs * MakeConst(prhs->dtype(), rhs_cscale); - *common_scale = new_common_scale * MakeConst(prhs->dtype(), cscale_gcd); + *plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale); + *prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale); + *common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd); return true; } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { - if 
(!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } @@ -958,7 +960,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } } return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); @@ -980,7 +982,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } PrimExpr a = this->CanonicalMutate(op->a); @@ -1019,7 +1021,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } } // Identity: floordiv(floormod(index, m*n), n) = floormod(floordiv(index, n), m) @@ -1049,7 +1051,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } // Apply floormod(floordiv_result, m) to complete the identity PrimExpr div_result = Normalize(lhs); - return this->VisitExpr(floormod(div_result, MakeConst(a.dtype(), new_mod))); + return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod))); } } } @@ -1095,8 +1097,8 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, // Perhaps there are more chances in simplifying the index // Do a recursive call to simplify the mod with the new factor. 
if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(this->VisitExpr( - ModImpl(lhs->index, MakeConst(lhs.dtype(), new_upper_factor), div_mode))); + auto updated = ToSplitExpr( + this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -1126,7 +1128,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1144,7 +1146,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); if (extra->IsZero()) { - return IntImm(a.dtype(), 0); + return IntImm(a.ty(), 0); } // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && @@ -1200,7 +1202,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1362,7 +1364,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { } PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Rewriter::VisitExpr_(op); } // normalize @@ -1370,15 +1372,15 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { // PushCastToChildren if (value.as()) { SumExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->dtype, analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->dtype); + if (se->CanPushCastToChildren(op->ty(), analyzer_)) { + 
se.CopyOnWrite()->PushCastToChildren(op->ty()); return se; } } if (value.as()) { SplitExpr se = value.as_or_throw(); - if (se->CanPushCastToChildren(op->dtype, analyzer_)) { - se.CopyOnWrite()->PushCastToChildren(op->dtype); + if (se->CanPushCastToChildren(op->ty(), analyzer_)) { + se.CopyOnWrite()->PushCastToChildren(op->ty()); return se; } } @@ -1411,8 +1413,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { } SumExpr divisible, extra; SeparateDivisibleParts(lhs, gcd, &divisible, &extra); - DataType dtype = divisible->dtype; - TVM_FFI_ICHECK(extra->dtype == dtype); + PrimType dtype = divisible->ty(); + TVM_FFI_ICHECK(extra->ty()->dtype == dtype->dtype); PrimExpr normal_extra = extra->Normalize(); if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) && this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) { diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index fb1055660e3b..ed1fc2d1a7a6 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -72,18 +72,29 @@ inline ffi::Optional TryConstFold(PrimExpr a); * \param type The type to represent index. * \return the checked result. */ -inline bool IsIndexType(const DataType& type) { - return type.is_int() && !type.is_scalable_or_fixed_length_vector() && - (type.bits() == 32 || type.bits() == 64); +inline bool IsIndexType(DLDataType type) { + return type.code == static_cast(DLDataTypeCode::kDLInt) && + (type.bits == 32 || type.bits == 64) && type.lanes == 1; +} + +inline bool IsIndexTypedExpr(const PrimExprNode* expr) { + TVM_FFI_DCHECK(expr != nullptr); + TVM_FFI_DCHECK(expr->BaseExprNode::ty.defined()); + const auto* prim_ty = expr->BaseExprNode::ty.as(); + TVM_FFI_DCHECK(prim_ty != nullptr); + return IsIndexType(prim_ty->dtype); +} + +inline bool IsIndexTypedExpr(const PrimExpr& expr) { + return IsIndexTypedExpr(static_cast(expr.get())); } /*! \brief Helper to get const folding result repr in int64. 
*/ -inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) { +inline int64_t GetFoldResultInt64Repr(int64_t x, const PrimType& dtype) { if (dtype.bits() < 64) { x &= (1LL << dtype.bits()) - 1; } - if (dtype.is_int()) { - // get sign extended value of integer with specified bits + if (dtype.code() == DLDataTypeCode::kDLInt) { int64_t m = 1LL << (dtype.bits() - 1); x = (x ^ m) - m; } @@ -118,32 +129,30 @@ inline double GetFoldResultDoubleRepr(float x) { const FloatImmNode* fb = b.as(); \ BODY; -#define TVM_INDEX_CONST_PROPAGATION(BODY) \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const DataType& ta = a.dtype(); \ - const DataType& tb = b.dtype(); \ - if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ - BODY; \ +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + if (arith::IsIndexTypedExpr(a) && arith::IsIndexTypedExpr(b)) { \ + BODY; \ } // specialization of constant folders. 
template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value + pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) + - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value + fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) + + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value + fb->value); } } if (fa && fa->value == 0) return b; @@ -155,22 +164,22 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - TVM_FFI_ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && - (pb && pb->dtype.is_uint() && pb->value > 0U))) + TVM_FFI_ICHECK(!((pa && pa->ty().code() == DLDataTypeCode::kDLUInt && pa->value == 0U) && + (pb && pb->ty().code() == DLDataTypeCode::kDLUInt && pb->value > 0U))) << "Checked failed. 
Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value - pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pb && pb->value == 0) return a; if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) - - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value - fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) - + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value - fb->value); } } if (fb && fb->value == 0) return a; @@ -181,10 +190,10 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { int64_t res = pa->value * pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 1) return b; @@ -195,11 +204,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pb->value == 0) return b; } if (fa && fb) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) * - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value * fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) * + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value * fb->value); } } if (fa) { @@ -217,13 +226,13 @@ inline 
ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value / pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; @@ -234,11 +243,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } if (fa && fb) { TVM_FFI_ICHECK_NE(fb->value, 0) << "Divide by zero"; - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) / - static_cast(fb->value))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, fa->value / fb->value); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, GetFoldResultDoubleRepr(static_cast(fa->value) / + static_cast(fb->value))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, fa->value / fb->value); } } if (fa && fa->value == 0) return a; @@ -253,18 +262,18 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value % pb->value; - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; } if (pb) { // MakeConst can handle both vector and scalar types. 
- if (pb->value == 1) return tirx::MakeConst(rtype, 0); + if (pb->value == 1) return tirx::MakeConst(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -274,11 +283,11 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floordiv(pa->value, pb->value); - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; @@ -288,11 +297,12 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - if (rtype.bits() == 32) { - return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast(fa->value) / - static_cast(fb->value)))); - } else if (rtype.bits() == 64) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); + if (result_ty.bits() == 32) { + return FloatImm(result_ty, + GetFoldResultDoubleRepr(std::floor(static_cast(fa->value) / + static_cast(fb->value)))); + } else if (result_ty.bits() == 64) { + return FloatImm(result_ty, std::floor(fa->value / fb->value)); } else { return std::nullopt; } @@ -309,18 +319,18 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); + PrimType result_ty = a.ty(); if (pa && pb) { TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floormod(pa->value, pb->value); - return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); + return IntImm(result_ty, GetFoldResultInt64Repr(res, result_ty)); } if (pa) { if (pa->value == 0) return a; } if (pb) { // MakeConst can handle both vector 
and scalar types. - if (pb->value == 1) return tirx::MakeConst(rtype, 0); + if (pb->value == 1) return tirx::MakeConst(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -330,9 +340,9 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); + PrimType result_ty = a.ty(); + if (pa && pb) return IntImm(result_ty, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm(result_ty, std::min(fa->value, fb->value)); }); if (a.same_as(b)) return a; return std::nullopt; @@ -341,9 +351,9 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); + PrimType result_ty = a.ty(); + if (pa && pb) return IntImm(result_ty, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm(result_ty, std::max(fa->value, fb->value)); }); if (a.same_as(b)) return a; return std::nullopt; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 4d700564ea05..3e8087af0eff 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -151,7 +151,7 @@ class ConstIntBoundAnalyzer::Impl // Override visitor behaviors Entry VisitExprDefault_(const ffi::Object* op) final { - return Everything(static_cast(op)->dtype); + return Everything(static_cast(op)->ty()); } Entry VisitExpr(const PrimExpr& expr) final { @@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl if (bound_) { auto val = bound_->find(expr); if (val != bound_->end()) { - auto everything = 
Everything(expr->dtype); + auto everything = Everything(expr->ty()); TVM_FFI_ICHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && @@ -203,7 +203,7 @@ class ConstIntBoundAnalyzer::Impl a = VisitExpr(op->value); } - Entry b = Everything(op->dtype); + Entry b = Everything(op->ty()); return Intersect(a, b); } @@ -263,7 +263,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareDiv); + return HandleDivision(a, b, op->ty(), InfAwareDiv); } Entry VisitExpr_(const ModNode* op) final { @@ -312,14 +312,14 @@ class ConstIntBoundAnalyzer::Impl TVM_FFI_ICHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. - return Everything(op->dtype); + return Everything(op->ty()); } } Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); Entry b = AssumeNoZeroDivisor(VisitExpr(op->b)); - return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); + return HandleDivision(a, b, op->ty(), InfAwareFloorDiv); } Entry VisitExpr_(const FloorModNode* op) final { @@ -385,7 +385,7 @@ class ConstIntBoundAnalyzer::Impl int64_t b_max_cap = InfAwareAdd(b.max_value, -1); return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), std::max(static_cast(0), b_max_cap)), - Everything(op->dtype)); + Everything(op->ty())); } } @@ -424,7 +424,7 @@ class ConstIntBoundAnalyzer::Impl } else if (op->op.same_as(tirx::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -434,7 +434,7 @@ class ConstIntBoundAnalyzer::Impl if (it != var_map_.end()) { return it->second; } else { - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -456,7 +456,7 @@ class ConstIntBoundAnalyzer::Impl // If either 
operand can negative, we may run into undefined // behavior for some targets. In these cases, avoid making any // assumptions about the result. - return Everything(op->dtype); + return Everything(op->ty()); } return BinaryOpBoundary(a, b, InfAwareLeftShift); @@ -481,7 +481,7 @@ class ConstIntBoundAnalyzer::Impl if (a.min_value >= 0) { return MakeBound(0, a.max_value); } - return Everything(op->dtype); + return Everything(op->ty()); } } @@ -549,7 +549,7 @@ class ConstIntBoundAnalyzer::Impl * \return The result. */ template - static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) { + static Entry HandleDivision(Entry a, Entry b, PrimType dt, const F& op) { // Here we have a / b. // The largest value of the division will be for the smallest (with // respect to the absolute value) value of b. If the range of b starts @@ -557,7 +557,7 @@ class ConstIntBoundAnalyzer::Impl // be closer to 0, because BinaryOpBoundary only checks end-points of // the domain ranges. // If the range of b contains 0, then some infinity will be involved - if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) { + if (b.min_value <= 0 && 0 <= b.max_value && dt.code() == DLDataTypeCode::kDLInt) { Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt); @@ -566,7 +566,7 @@ class ConstIntBoundAnalyzer::Impl return MakeBound(std::min(e_neg.min_value, e_pos.min_value), std::max(e_neg.max_value, e_pos.max_value)); - } else if (b.min_value == 0 && dt.is_uint()) { + } else if (b.min_value == 0 && dt.code() == DLDataTypeCode::kDLUInt) { // uints only have one sided bounds Entry assumed_b = MakeBound(1, b.max_value); return BinaryOpBoundary(a, assumed_b, op); @@ -727,16 +727,17 @@ class ConstIntBoundAnalyzer::Impl * \param dtype The data type. * \return Bound that represent everything dtype can represent. 
*/ - static Entry Everything(DataType dtype) { - if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) { + static Entry Everything(PrimType dtype) { + if (dtype.code() != DLDataTypeCode::kDLInt && dtype.code() != DLDataTypeCode::kDLUInt && + dtype.code() != DLDataTypeCode::kDLBool) { return MakeBound(kNegInf, kPosInf); } - if (dtype.is_bool()) { + if (dtype.code() == DLDataTypeCode::kDLBool) { return MakeBound(0, 1); } Entry ret; - int64_t vbits = dtype.bits() - static_cast(dtype.is_int()); - if (dtype.is_uint()) { + int64_t vbits = dtype.bits() - static_cast(dtype.code() == DLDataTypeCode::kDLInt); + if (dtype.code() == DLDataTypeCode::kDLUInt) { ret.min_value = 0; } else { if (vbits >= 63) { @@ -800,7 +801,7 @@ class ConstIntBoundAnalyzer::Impl static ffi::Optional FindCeilLog2Arg(const CastNode* op) { static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); - if (op->dtype.is_int()) { + if (op->ty().code() == DLDataTypeCode::kDLInt) { if (auto as_call = op->value.as()) { if (as_call->op.same_as(ceil_op)) { PrimExpr ceil_arg = as_call->args[0]; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 5e77dca59405..f7e04ee0ebf5 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -54,10 +54,10 @@ class LinearEqDetector : public ExprFunctorbase.defined()) { - ret->base = IntImm(var_.dtype(), 0); + ret->base = IntImm(var_.ty(), 0); } if (!ret->coeff.defined()) { - ret->coeff = IntImm(var_.dtype(), 0); + ret->coeff = IntImm(var_.ty(), 0); } return true; } @@ -101,8 +101,8 @@ class LinearEqDetector : public ExprFunctordtype; - ret.coeff = MakeConst(DataType::Int(dtype.bits(), dtype.lanes()), 1); + PrimType dtype = op->ty(); + ret.coeff = MakeConst(PrimType::Int(dtype.bits(), dtype.lanes()), 1); } else { ret.base = e; } @@ -194,19 +194,21 @@ bool DetectClipBound(const PrimExpr& cond, bool is_eq = false; PrimExpr canonical; if (const LTNode* op = 
cond.as()) { - if (!op->a.dtype().is_int()) return false; - canonical = op->b - op->a - MakeConst(op->a.dtype(), 1); + PrimType a_ty = op->a.ty(); + if (a_ty.code() != DLDataTypeCode::kDLInt) return false; + canonical = op->b - op->a - MakeConst(a_ty, 1); } else if (const LENode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->b - op->a; } else if (const GTNode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; - canonical = op->a - op->b - MakeConst(op->a.dtype(), 1); + PrimType a_ty = op->a.ty(); + if (a_ty.code() != DLDataTypeCode::kDLInt) return false; + canonical = op->a - op->b - MakeConst(a_ty, 1); } else if (const GENode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->a - op->b; } else if (const EQNode* op = cond.as()) { - if (!op->a.dtype().is_int()) return false; + if (op->a.ty().code() != DLDataTypeCode::kDLInt) return false; canonical = op->a - op->b; is_eq = true; } else { diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 55db4fc774b6..bcd957aac0f2 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -74,7 +74,9 @@ ffi::Array AsConditions(const ffi::Array& variables, IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, ffi::Array equal, ffi::Array upper) { - TVM_FFI_ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) + PrimType coef_ty = coef.ty(); + TVM_FFI_ICHECK(coef_ty.code() == DLDataTypeCode::kDLInt || + coef_ty.code() == DLDataTypeCode::kDLUInt) << "Coefficient in IntGroupBounds must be integers"; ffi::ObjectPtr node = ffi::make_object(); node->coef = std::move(coef); @@ -86,7 +88,7 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; - PrimExpr coef = tirx::MakeConst(r->min.dtype(), 
1); + PrimExpr coef = tirx::MakeConst(r->min.ty(), 1); ffi::Array equal; ffi::Array lower; ffi::Array upper; @@ -232,7 +234,9 @@ IntConstraints::IntConstraints(ffi::Array variables, ffi::Map r } TVM_FFI_ICHECK(relations.defined()); for (const auto& var : variables) { - TVM_FFI_ICHECK(var.dtype().is_int() || var.dtype().is_uint()) + PrimType var_ty = var.ty(); + TVM_FFI_ICHECK(var_ty.code() == DLDataTypeCode::kDLInt || + var_ty.code() == DLDataTypeCode::kDLUInt) << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index d7bf32442497..ac966582e766 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -50,8 +50,8 @@ using tirx::MakeConst; TVM_FFI_STATIC_INIT_BLOCK() { IntervalSetNode::RegisterReflection(); } -PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); -PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); +PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", PrimType::Handle()); +PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", PrimType::Handle()); IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { auto node = ffi::make_object(); @@ -72,8 +72,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { IntervalSet Intersect(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); PrimExpr min_value = max(a->min_value, b->min_value); - if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && - (min_value.dtype().is_int() || min_value.dtype().is_uint()) && + PrimType max_ty = max_value.ty(); + PrimType min_ty = min_value.ty(); + if ((max_ty.code() == DLDataTypeCode::kDLInt || max_ty.code() == DLDataTypeCode::kDLUInt) && + (min_ty.code() == DLDataTypeCode::kDLInt || min_ty.code() == DLDataTypeCode::kDLUInt) && analyzer->CanProve(max_value < min_value)) { return IntervalSet::Empty(); } else { @@ -121,7 +123,7 @@ TVM_DECLARE_LOGICAL_OP(Not); */ template inline 
IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { - DataType dtype = op->dtype; + PrimType dtype = op->ty(); if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -195,7 +197,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -229,7 +231,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -258,7 +260,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte // is the case of our application. // TODO(tqchen): add bound constraints for a. 
if (analyzer->CanProveGreaterEqual(divisor, 0)) { - return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); + return IntervalSet(IntImm(divisor.ty(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -292,7 +294,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); + PrimExpr sign = b->min_value >= IntImm(b->min_value.ty().WithLanes(1), 0); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -323,7 +325,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); // We can compare +/- inf against each other, but cannot use // operator== between the symbolic limits and an integer. 
- bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); + bool compatible_dtypes = !(qmin.ty().IsHandle() ^ qmax.ty().IsHandle()); if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { auto tmax = a->max_value - divisor * qmin; auto tmin = a->min_value - divisor * qmin; @@ -348,12 +350,13 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); if (max_mod_result >= 0 && max_mod_result < div_val) { - return IntervalSet(IntImm(op->dtype, 0), IntImm(op->dtype, max_mod_result)); + PrimType result_ty = ffi::GetRef(op).ty(); + return IntervalSet(IntImm(result_ty, 0), IntImm(result_ty, max_mod_result)); } } } } - return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); + return IntervalSet(IntImm(divisor.ty(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -522,7 +525,7 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet base = Eval(op->base); PVar stride; if (stride.Match(op->stride)) { - DataType t = op->base.dtype(); + PrimType t = op->base.ty(); int64_t vstride = stride.Eval()->value; if (op->lanes->IsInstance()) { int lanes = static_cast(op->lanes.as_or_throw()->value); @@ -569,18 +572,19 @@ class IntervalSetEvaluator : public ExprFunctor { // short cut for the int set. if (value_set->min_value.same_as(value_set->max_value)) { if (value_set->IsEmpty()) return value_set; - return IntervalSet::SinglePoint(cast(op->dtype, value_set->min_value)); + return IntervalSet::SinglePoint(cast(op->ty(), value_set->min_value)); } PrimExpr min_value = - value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf(); + value_set->HasLowerBound() ? cast(op->ty(), value_set->min_value) : neg_inf(); PrimExpr max_value = - value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf(); + value_set->HasUpperBound() ? 
cast(op->ty(), value_set->max_value) : pos_inf(); return IntervalSet(min_value, max_value); } IntervalSet VisitExpr_(const BufferLoadNode* op) final { - if (!(op->dtype.is_int() || op->dtype.is_uint())) { - DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op->dtype + PrimType op_ty = op->ty(); + if (!(op_ty.code() == DLDataTypeCode::kDLInt || op_ty.code() == DLDataTypeCode::kDLUInt)) { + DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op_ty->dtype << " buffer"; return IntervalSet::Everything(); } @@ -1048,7 +1052,7 @@ IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map) { IntSet IntSet::Vector(PrimExpr x) { // short cut: simply get single point - if (!x.dtype().is_scalable_or_fixed_length_vector()) { + if (!x.ty().IsScalableVector() && !x.ty().IsFixedLengthVector()) { return IntSet::SinglePoint(x); } else { // vector case. @@ -1068,7 +1072,9 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom IntSet EvalSet(Range r, const ffi::Map& dom_map) { Analyzer ana; - if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana->CanProveEqual(r->extent, 1)) { + PrimType min_ty = r->min.ty(); + if ((min_ty.code() == DLDataTypeCode::kDLInt || min_ty.code() == DLDataTypeCode::kDLUInt) && + ana->CanProveEqual(r->extent, 1)) { return EvalSet(r->min, dom_map); } IntervalSetEvaluator m(ana.get(), dom_map); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 8dcef7a75a80..d6a264288b16 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -54,7 +54,7 @@ void AppendFloorDivConstraints(const FloorDivNode* div, int64_t value, CompareKi int64_t divisor_value = 0; if (!TryGetIntImm(div->b, &divisor_value) || divisor_value <= 0) return; - DataType dtype = div->a.dtype(); + PrimType dtype = div->a.ty(); PrimExpr divisor = MakeConst(dtype, divisor_value); PrimExpr k = MakeConst(dtype, value); PrimExpr lo = k * divisor; @@ -117,7 +117,8 @@ void 
CollectDerivedConstraintFacts(const PrimExpr& condition, std::vector()) { if (call->op.same_as(tirx::builtin::bitwise_and()) && call->args.size() == 2 && - call->args[0].dtype().is_bool() && call->args[1].dtype().is_bool()) { + call->args[0].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8) && + call->args[1].ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { CollectDerivedConstraintFacts(call->args[0], out); CollectDerivedConstraintFacts(call->args[1], out); return; @@ -260,7 +261,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = op->node.as_or_throw(); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - Range dom = Range::FromMinExtent(IntImm(op->value.dtype(), 0), op->value); + Range dom = Range::FromMinExtent(IntImm(op->value.ty(), 0), op->value); analyzer_->Bind(iv->var, dom); iter_vars_.Set(iv->var, dom); } @@ -313,7 +314,8 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return ffi::GetRef(op); } else { - return Call(op->dtype, op->op, {cond, true_value, false_value}, op->attrs, op->span); + return Call(ffi::GetRef(op).ty(), op->op, {cond, true_value, false_value}, + op->attrs, op->span); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index ffe9c73bd6f2..0313dbfe4271 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -79,7 +79,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = op->node.as_or_throw(); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value.ty(), 0), 
op->value)); } StmtExprVisitor::VisitStmt_(op); }); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c7f8819f944f..430a4ec5c839 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -66,8 +66,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSplitExpr::IterSplitExpr(IterMark source) { auto n = ffi::make_object(); - auto one = MakeConst(source->source->dtype, 1); - n->dtype = source->source->dtype; + auto one = MakeConst(source->source.ty(), 1); + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->extent = n->source->extent; n->lower_factor = one; @@ -77,8 +77,8 @@ IterSplitExpr::IterSplitExpr(IterMark source) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { auto n = ffi::make_object(); - auto one = MakeConst(source->source->dtype, 1); - n->dtype = source->source->dtype; + auto one = MakeConst(source->source.ty(), 1); + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->extent = n->source->extent; n->lower_factor = one; @@ -89,7 +89,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { auto n = ffi::make_object(); - n->dtype = source->source->dtype; + n->BaseExprNode::ty = source->source.ty(); n->source = std::move(source); n->lower_factor = std::move(lower_factor); n->extent = std::move(extent); @@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSumExpr::IterSumExpr(ffi::Array args, PrimExpr base) { auto n = ffi::make_object(); - n->dtype = base->dtype; + n->BaseExprNode::ty = base.ty(); n->args = std::move(args); n->base = std::move(base); data_ = std::move(n); @@ -563,7 +563,7 @@ class IterMapRewriter : public ExprMutator { IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; - PrimExpr expected_lower_factor = MakeConst(mark->source->dtype, 1); + PrimExpr expected_lower_factor = 
MakeConst(mark->source.ty(), 1); for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; @@ -694,7 +694,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; // the delta of iter_min when it is updated when the lower bound predicate is present - PrimExpr iter_min_delta = IntImm(iter_min.dtype(), 0); + PrimExpr iter_min_delta = IntImm(iter_min.ty(), 0); if (predicate_induced_min.defined()) { iter_min_delta = max(predicate_induced_min.value(), iter_min) - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); @@ -788,7 +788,7 @@ class IterMapRewriter : public ExprMutator { for (IterSplitExpr split : expr->args) { int64_t symbol_prod_count = 0; int64_t cscale = 1; - PrimExpr res = tirx::MakeConst(split.dtype(), 1); + PrimExpr res = tirx::MakeConst(split.ty(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -799,7 +799,7 @@ class IterMapRewriter : public ExprMutator { }; UnpackReduction(split->scale, fcollect); if (cscale != 1) { - res = res * tirx::MakeConst(res.dtype(), cscale); + res = res * tirx::MakeConst(res.ty(), cscale); } split.CopyOnWrite()->scale = res; items.emplace_back(Item{cscale, symbol_prod_count, split}); @@ -830,7 +830,7 @@ class IterMapRewriter : public ExprMutator { if (auto op = expr.as()) { return op.value(); } else if (auto op = expr.as()) { - return IterSumExpr({op.value()}, IntImm(expr->dtype, 0)); + return IterSumExpr({op.value()}, IntImm(expr.ty(), 0)); } else { TVM_FFI_ICHECK(!expr->IsInstance()); return IterSumExpr({}, expr); @@ -1103,8 +1103,8 @@ class IterMapRewriter : public ExprMutator { std::vector flattened_iters, grouped_iters; // check if it can be remapped into a fused pattern. 
- PrimExpr expected_extra_base = IntImm(expr.dtype(), 0); - PrimExpr tail_extent = IntImm(expr.dtype(), 0); + PrimExpr expected_extra_base = IntImm(expr.ty(), 0); + PrimExpr tail_extent = IntImm(expr.ty(), 0); PrimExpr expected_scale = base_scale; int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr); @@ -1200,10 +1200,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = ffi::Array(flattened_iters.rbegin(), flattened_iters.rend()); - flattened_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); + flattened_form.CopyOnWrite()->base = IntImm(expr.ty(), 0); structured_form.CopyOnWrite()->args = ffi::Array(grouped_iters.rbegin(), grouped_iters.rend()); - structured_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); + structured_form.CopyOnWrite()->base = IntImm(expr.ty(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter @@ -1245,7 +1245,7 @@ class IterMapRewriter : public ExprMutator { if (sign > 0) { lhs->args.push_back(rhs); } else { - rhs.CopyOnWrite()->scale = IntImm(rhs->scale.dtype(), 0) - rhs->scale; + rhs.CopyOnWrite()->scale = IntImm(rhs->scale.ty(), 0) - rhs->scale; lhs->args.push_back(rhs); } } @@ -1332,8 +1332,10 @@ bool MatchBoundConstraints(PrimExpr pred, ffi::Map* input_iters, PrimExpr lhs_expr = lhs.Eval(); PrimExpr rhs_expr = rhs.Eval(); // we only accept predicate of integers - if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && - (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { + PrimType lhs_ty = lhs_expr.ty(); + PrimType rhs_ty = rhs_expr.ty(); + if (!((lhs_ty.code() == DLDataTypeCode::kDLInt || lhs_ty.code() == DLDataTypeCode::kDLUInt) && + (rhs_ty.code() == DLDataTypeCode::kDLInt || rhs_ty.code() == DLDataTypeCode::kDLUInt))) { return false; } // determine iter and bound, if we can not distinguish them simply, @@ -1563,7 +1565,7 @@ PrimExpr 
IterMapRewriter::VisitExpr_(const VarNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } PrimExpr a = this->DirectMutate(op->a); @@ -1596,7 +1598,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -1631,7 +1633,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { } PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } // normalize @@ -1677,7 +1679,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { if (dividend->IsInstance()) { auto split = dividend.as_or_throw(); - return IterSumExpr({split}, IntImm(split.dtype(), 0)); + return IterSumExpr({split}, IntImm(split.ty(), 0)); } else if (dividend->IsInstance()) { auto sum = dividend.as_or_throw(); if (sum->args.empty()) { @@ -1880,12 +1882,12 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1); } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale base = floordiv(base, lhs->scale); rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs.ty(), 1); } else { // mark as unresolved. 
ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " @@ -1931,7 +1933,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P new_split = IterSplitExpr(IterMark(padded, padded->extent), /* lower_factor = */ rhs, /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)), - /* scale = */ MakeConst(rhs->dtype, 1)); + /* scale = */ MakeConst(rhs.ty(), 1)); } auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); @@ -1944,7 +1946,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -1987,13 +1989,13 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P if (is_one(rhs)) { // floormod(x, 1) = 0 - return IntImm(lhs->dtype, 0); + return IntImm(lhs.ty(), 0); } if (!is_one(lhs->scale)) { if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { // floormod(x*c1*c2, c1) = 0 - return IntImm(lhs->dtype, 0); + return IntImm(lhs.ty(), 0); } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale rhs = floordiv(rhs, lhs->scale); @@ -2028,7 +2030,7 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P } PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { - if (!IsIndexType(op->dtype)) { + if (!IsIndexTypedExpr(op)) { return Parent::VisitExpr_(op); } @@ -2113,7 +2115,7 @@ class IterMapToExprNormalizer : public ExprMutator { // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis // like tensorization. 
if (is_one(expr->extent) && !is_one(expr->source->extent)) { - return IntImm(expr->extent->dtype, 0); + return IntImm(expr->extent.ty(), 0); } return floordiv(source, expr->lower_factor) * expr->scale; } else { @@ -2255,13 +2257,13 @@ class SubspaceDivider { IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); } static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { - auto dtype = iter.dtype(); + PrimType dtype = iter.ty(); return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), iter, extent, Kind::kInner); } static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { - auto dtype = iter.dtype(); + PrimType dtype = iter.ty(); return DivisionResult(iter, extent, IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), Kind::kOuter); } @@ -2285,7 +2287,7 @@ class SubspaceDivider { // Divide an IterSumExpr DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) { - auto dtype = expr.dtype(); + PrimType dtype = expr.ty(); if (expr->args.empty()) { // base return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), @@ -2377,7 +2379,7 @@ class SubspaceDivider { // args are sorted from inner to outer static IterMark MarkFromArgsAndBase(const std::vector& args, PrimExpr base) { std::vector res; - PrimExpr extent = MakeConst(base.dtype(), 1); + PrimExpr extent = MakeConst(base.ty(), 1); for (const IterSplitExpr& it : args) { IterSplitExpr arg = it; arg.CopyOnWrite()->scale = extent; @@ -2431,7 +2433,7 @@ class SubspaceDivider { bool encountered_boundary = mark_division.IsOuter(); std::vector used(splits.size(), false); std::vector inner_iters, outer_iters; - PrimExpr expected_lower_factor = MakeConst(expr->source->source->dtype, 1); + PrimExpr expected_lower_factor = MakeConst(expr->source->source.ty(), 1); // find the boundary of outer and inner, like case 1 above for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; diff --git 
a/src/arith/pattern_match.h b/src/arith/pattern_match.h index bb1ebd54cca7..f6a052089842 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -71,6 +71,7 @@ #include #include +#include #include #include "const_fold.h" @@ -199,7 +200,10 @@ class PVar : public Pattern> { // Store PVars by reference in the expression. using Nested = const PVar&; - void InitMatch_() const { filled_ = false; } + void InitMatch_() const { + value_.reset(); + filled_ = false; + } bool Match_(const T& value) const { if (!filled_) { @@ -207,7 +211,7 @@ class PVar : public Pattern> { filled_ = true; return true; } else { - return PEqualChecker()(value_, value); + return PEqualChecker()(value_.value(), value); } } @@ -223,14 +227,14 @@ class PVar : public Pattern> { T Eval() const { TVM_FFI_ICHECK(filled_); - return value_; + return value_.value(); } - T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } + T EvalOr(const T& default_value) const { return filled_ ? value_.value() : default_value; } protected: /*! \brief The matched value */ - mutable T value_; + mutable std::optional value_; /*! \brief whether the variable has been filled */ mutable bool filled_{false}; }; @@ -282,7 +286,7 @@ class PVarWithDataType : public PVarWithCheck, T> { public: explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} - bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } + bool Match_(const T& value) const { return dtype_.Match_(value.ty()); } protected: typename DType::Nested dtype_; @@ -291,15 +295,15 @@ class PVarWithDataType : public PVarWithCheck, T> { /*! * \brief Pattern variable container for data type with lanes. */ -class PVecDataType : public PVarWithCheck { +class PVecDataType : public PVarWithCheck { public: /*! 
\brief construct vector dtype placeholder with element type check */ - explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} + explicit PVecDataType(PrimType elem_dtype) : elem_dtype_(elem_dtype) {} - bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } + bool Match_(PrimType dtype) const { return dtype.code() == elem_dtype_.code(); } protected: - DataType elem_dtype_; + PrimType elem_dtype_; }; /*! @@ -377,7 +381,7 @@ class PConstWithTypeLike : public Pattern> { } } - PrimExpr Eval() const { return tirx::MakeConst(ref_.Eval().dtype(), value_); } + PrimExpr Eval() const { return tirx::MakeConst(ref_.Eval().ty(), value_); } private: typename TA::Nested ref_; @@ -540,7 +544,7 @@ class PCastExpr : public Pattern> { bool Match_(const ffi::ObjectRef& node) const { if (const tirx::CastNode* ptr = node.as()) { - if (!dtype_.Match_(ptr->dtype)) return false; + if (!dtype_.Match_(ptr->ty())) return false; if (!value_.Match_(ptr->value)) return false; return true; } else { @@ -558,7 +562,7 @@ class PCastExpr : public Pattern> { /*! * \brief Construct a cast pattern. * - * \param dtype The target data type, can be PVar or PConst. + * \param dtype The target data type, can be PVar or PConst. * \param value The input type. * * \return The result pattern. 
@@ -780,7 +784,7 @@ class PCallExpr : public Pattern> { #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].dtype(), GetOp(), args); \ + return tirx::Call(args[0].ty(), GetOp(), args); \ } \ static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ @@ -799,7 +803,7 @@ TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tirx::Call(args[0].dtype(), GetOp(), args); \ + return tirx::Call(args[0].ty(), GetOp(), args); \ } \ static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ @@ -813,7 +817,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(ffi::Array args) { - return tirx::Call(args[1].dtype(), GetOp(), args); + return tirx::Call(args[1].ty(), GetOp(), args); } static const Op& GetOp() { return tirx::builtin::if_then_else(); } }; @@ -841,7 +845,7 @@ inline PCallExpr if_then_else(const Pattern // vscale struct PVscaleOp { - static PrimExpr Eval() { return tirx::Call(DataType::Int(32), GetOp(), {}); } + static PrimExpr Eval() { return tirx::Call(PrimType::Int(32), GetOp(), {}); } static const Op& GetOp() { return tirx::builtin::vscale(); } }; diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h index 40d02c1952b7..79e040287fa7 100644 --- a/src/arith/product_normal_form.h +++ b/src/arith/product_normal_form.h @@ -79,7 +79,8 @@ inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) { */ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { int64_t cscale = 1; - PrimExpr res = tirx::MakeConst(lhs.dtype(), 1); + PrimType lhs_ty = lhs.ty(); + PrimExpr res = tirx::MakeConst(lhs_ty, 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = 
val.as()) { cscale *= intimm->value; @@ -90,7 +91,7 @@ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { UnpackReduction(lhs, fcollect); UnpackReduction(rhs, fcollect); if (cscale != 1) { - res = res * tirx::MakeConst(res.dtype(), cscale); + res = res * tirx::MakeConst(res.ty(), cscale); } return res; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index fa3ba0b519d6..07ea2c7a7778 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -425,7 +425,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); @@ -433,7 +433,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Index rules // cancelation rules TVM_TRY_REWRITE((x - y) + y, x); @@ -535,7 +535,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; - if (subconstraint.dtype().is_bool()) { + if (subconstraint.ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { // We could apply NormalizeBooleanOperators during // TryMatchLiteralConstraint, but that would require // performing a rewrite of each expression being checked. @@ -543,7 +543,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c // applied. 
negation = NormalizeBooleanOperators(Not(subconstraint)); } else { - negation = subconstraint == IntImm(subconstraint.dtype(), 0); + negation = subconstraint == IntImm(subconstraint.ty(), 0); } literal_constraints_.push_back(Not(negation)); } @@ -575,14 +575,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes)); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Index rules // cancelation rules TVM_TRY_REWRITE(matches_one_of((x + y) - y, (y + x) - y), x); @@ -765,7 +765,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { // Pattern var for lanes in broadcast and ramp PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(matches_one_of(ramp(b1, s1, lanes) * broadcast(x, lanes), broadcast(x, lanes) * ramp(b1, s1, lanes)), @@ -773,7 +773,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // constant simplification rule TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); @@ -803,7 +803,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PVar lanes; // Vector rules - if 
(op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { // NOTE: use div as the pattern also works for float. TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); // ramp / bcast @@ -827,7 +827,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. @@ -839,7 +839,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return MakeConst(op->dtype, truncdiv(c1val, c2val)); + return MakeConst(op->ty(), truncdiv(c1val, c2val)); } // while it is always true for trunc div @@ -957,7 +957,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(truncmod(x, y), lanes)); @@ -994,7 +994,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. 
@@ -1019,7 +1019,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), truncmod(x, PConst(MakeConst(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(MakeConst(op->ty(), -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis @@ -1046,7 +1046,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floordiv(x, y), lanes)); // ramp // bcast @@ -1077,7 +1077,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: this is floor division. 
TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2), c1.Eval()->value > 0 && c2.Eval()->value > 0); @@ -1198,7 +1198,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PVar lanes; // Vector rules - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)), broadcast(floormod(x, y), lanes)); @@ -1238,7 +1238,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), c2.Eval()->value != 0); @@ -1314,12 +1314,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PVar lanes; // vector rule - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { TVM_TRY_REWRITE(min(x, x), x); // constant int bound @@ -1498,12 +1498,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PVar lanes; // vector rule - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } - if (IsIndexType(op->dtype)) { + if (IsIndexTypedExpr(op)) { TVM_TRY_REWRITE(max(x, x), x); // constant int bound @@ -1686,10 +1686,10 @@ ffi::Optional 
RewriteSimplifier::Impl::TryMatchLiteralConstraint( ExprDeepEqual expr_equal; for (const auto& constraint : literal_constraints_) { if (expr_equal(constraint, expr)) { - return MakeConst(expr->dtype, true); + return MakeConst(expr->ty(), true); } if (expr_equal(constraint, negation)) { - return MakeConst(expr->dtype, false); + return MakeConst(expr->ty(), false); } } return std::nullopt; @@ -1715,20 +1715,20 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; - PConst ctrue(MakeConst(ret->dtype, true)); + PConst ctrue(MakeConst(ret->ty(), true)); // vector rule - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } - if (IsIndexType(ret->a.dtype())) { + if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kEQ) { - return MakeConst(ret->dtype, true); + return MakeConst(ret->ty(), true); } else if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(ret->dtype, false); + return MakeConst(ret->ty(), false); } TVM_TRY_REWRITE(c1 == x, x == c1); @@ -1758,13 +1758,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - if (IsIndexType(op->a.dtype())) { + if (IsIndexTypedExpr(op->a)) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return MakeConst(op->dtype, true); + return MakeConst(op->ty(), true); } else if (result == CompareResult::kEQ) { - return MakeConst(op->dtype, false); + return MakeConst(op->ty(), false); } else if (result == CompareResult::kGE) { // 
Known: a >= b // @@ -1802,13 +1802,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { // (floordiv(A,B)b, op->a)))); - if (auto op = ret.as(); op && IsIndexType(op->a.dtype())) { + if (auto op = ret.as(); op && IsIndexTypedExpr(op->a)) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kLE || result == CompareResult::kLT || result == CompareResult::kEQ) { - return MakeConst(op->dtype, true); + return MakeConst(op->ty(), true); } else if (result == CompareResult::kGT) { - return MakeConst(op->dtype, false); + return MakeConst(op->ty(), false); } else if (result == CompareResult::kNE) { // Known: a != b // @@ -1857,19 +1857,19 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { PVar lanes; // vector rule - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } - if (IsIndexType(ret->a.dtype())) { + if (IsIndexTypedExpr(ret->a)) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kLT) { - return MakeConst(ret->dtype, true); + return MakeConst(ret->ty(), true); } if (result == CompareResult::kEQ || result == CompareResult::kGT || result == CompareResult::kGE) { - return MakeConst(ret->dtype, false); + return MakeConst(ret->ty(), false); } // clang-format off @@ -1987,9 +1987,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { } else if (diff == 1) { return lhs <= rhs; } else if (diff < 0 && rhs_offset != 0) { - return lhs + MakeConst(lhs.dtype(), -diff) < rhs; + return lhs + MakeConst(lhs.ty(), -diff) < rhs; } else if (diff > 0 && lhs_offset != 0) { - return lhs < rhs + MakeConst(rhs.dtype(), diff); + return lhs < rhs + MakeConst(rhs.ty(), diff); } return std::nullopt; @@ -2024,7 +2024,7 @@ PrimExpr 
RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { // Pattern var to match any expression PVar x, y; PVar lanes; - if (ret->dtype.is_scalable_or_fixed_length_vector()) { + if (ret->ty().IsScalableVector() || ret->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); } @@ -2100,11 +2100,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PVar c1, c2, c3; PVar lanes; - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } - auto cfalse = PConst(MakeConst(op->dtype, false)); + auto cfalse = PConst(MakeConst(op->ty(), false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -2248,11 +2248,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PVar c1, c2; PVar lanes; - if (op->dtype.is_scalable_or_fixed_length_vector()) { + if (op->ty().IsScalableVector() || op->ty().IsFixedLengthVector()) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } - auto ctrue = PConst(MakeConst(op->dtype, true)); + auto ctrue = PConst(MakeConst(op->ty(), true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -2319,12 +2319,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { static const Op& ceil_op = Op::Get("tirx.ceil"); static const Op& log2_op = Op::Get("tirx.log2"); static const Op& clz_op = Op::Get("tirx.clz"); + PrimType ret_ty = ffi::GetRef(op).ty(); if (op->op.same_as(ceil_op)) { PrimExpr ceil_arg = op->args[0]; if (auto arg_int = op->args[0].as()) { - return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); + return cast(ret_ty, IntImm(ffi::GetRef(arg_int).ty(), arg_int->value)); } else if (auto arg_float = ceil_arg.as()) { - return cast(op->dtype, 
FloatImm(arg_float->dtype, std::ceil(arg_float->value))); + return cast(ret_ty, + FloatImm(ffi::GetRef(arg_float).ty(), std::ceil(arg_float->value))); } else if (auto arg_call = ceil_arg.as()) { // ceil(log2(cast(n,"float64"))) is used as the implementation of // topi.math.ceil_log2, and appears in iteration bounds. @@ -2334,17 +2336,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // ceil(log2(n)) can be simplified, and should produce the // same integer result regardless of the target's rounding // conventions. - return FloatImm(op->dtype, std::ceil(std::log2(as_float->value))); + return FloatImm(ret_ty, std::ceil(std::log2(as_float->value))); } } } } else if (op->op.same_as(clz_op)) { if (const auto* arg_int = op->args[0].as()) { - int bits = arg_int->dtype.bits(); - if (arg_int->value == 0) return MakeConst(op->dtype, bits); + int bits = arg_int->ty().bits(); + if (arg_int->value == 0) return MakeConst(ret_ty, bits); for (int i = bits - 1; i >= 0; --i) { if ((int64_t(1) << i) & arg_int->value) { - return IntImm(op->dtype, bits - i - 1); + return IntImm(ret_ty, bits - i - 1); } } TVM_FFI_THROW(InternalError) << "Should not reach here"; @@ -2373,7 +2375,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // Only check constant cases to avoid recursion if (is_const_number(inner_else_expr) && is_const_number(else_expr) && analyzer_->CanProve(inner_else_expr == else_expr)) { - return Call(op->dtype, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, + return Call(ret_ty, op->op, {cond && inner_cond, inner_then_expr, else_expr}, op->attrs, op->span); } } @@ -2384,7 +2386,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = ffi::GetRef(op); - if (op->dtype == DataType::Bool()) { + PrimType op_ty = op->ty(); + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBool, 8) && !op_ty.IsScalableVector() && + 
!op_ty.IsFixedLengthVector()) { if (auto match = TryMatchLiteralConstraint(var)) { return match.value(); } @@ -2400,7 +2404,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - return cast(op->dtype, op->value); + return cast(ret.ty(), op->value); } bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) { diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 27144c674b9f..fd507ccdd658 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -133,10 +133,10 @@ void SmithNormalFormDiag(std::vector>* S, std::vector>* S, std::vector()) { name_hint += "_" + v_old->name_hint; } - Var v = Var(name_hint, V_inv_x[j].dtype()); + Var v = Var(name_hint, V_inv_x[j].ty()); solution_for_V_inv_x.push_back(v); new_vars.push_back(v); new_to_old_map.Set(v, to_old); @@ -403,12 +403,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // The j-th variable is just a single value, don't create a tvm variable // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { - PrimExpr a = tirx::MakeConst(Uy[j].dtype(), S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].ty(), S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = tirx::MakeConst(Uy[j].dtype(), -S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].ty(), -S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(-Uy[j], a))); } } @@ -416,9 +416,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = 
IntImm(system_to_solve->variables[i].dtype(), 0); + PrimExpr e = IntImm(system_to_solve->variables[i].ty(), 0); for (size_t j = 0; j < num_vars; ++j) { - e = e + tirx::MakeConst(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; + e = e + tirx::MakeConst(e.ty(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem->Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 80d064f71157..14b1affb9927 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -24,9 +24,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -91,10 +91,12 @@ class NormalizeComparisons : public ExprMutator { template PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { // rewrite LT to LE for ints - if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { - return LE(analyzer_->Simplify(a - b + 1), IntImm(a.dtype(), 0)); + PrimType a_ty = a.ty(); + if (std::is_same::value && + (a_ty.code() == DLDataTypeCode::kDLInt || a_ty.code() == DLDataTypeCode::kDLUInt)) { + return LE(analyzer_->Simplify(a - b + 1), IntImm(a.ty(), 0)); } - return T(analyzer_->Simplify(a - b), IntImm(a.dtype(), 0)); + return T(analyzer_->Simplify(a - b), IntImm(a.ty(), 0)); } arith::Analyzer analyzer_; }; @@ -248,11 +250,12 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { for (const auto& neg : coef_neg) { auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); - PrimExpr c_pos = MakeConst(v.dtype(), neg.first / first_gcd); - PrimExpr c_neg = MakeConst(v.dtype(), pos.first / first_gcd); + PrimType v_ty = v.ty(); + PrimExpr c_pos = MakeConst(v_ty, neg.first / first_gcd); + PrimExpr c_neg = MakeConst(v_ty, pos.first / first_gcd); // eliminate the current variable PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; - 
PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.dtype(), 0)); + PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.ty(), 0)); // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 @@ -281,7 +284,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t lower_bounds.reserve(coef_neg.size()); for (const auto& pos : coef_pos) { - PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / pos.first) * pos.second; + PrimExpr bound = MakeConst(v.ty(), -coef_lcm / pos.first) * pos.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), @@ -302,7 +305,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t upper_bounds.push_back(bound); } for (const auto& neg : coef_neg) { - PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / neg.first) * neg.second; + PrimExpr bound = MakeConst(v.ty(), -coef_lcm / neg.first) * neg.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), @@ -330,7 +333,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. 
- IntGroupBounds bnds(MakeConst(v.dtype(), coef_lcm), + IntGroupBounds bnds(MakeConst(v.ty(), coef_lcm), ffi::Array(lower_bounds.begin(), lower_bounds.end()), ffi::Array(equal_list.begin(), equal_list.end()), ffi::Array(upper_bounds.begin(), upper_bounds.end())); @@ -509,7 +512,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ analyzer->Simplify(var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis - auto range = Range(IntImm(new_var.dtype(), 0), best_range->extent); + auto range = Range(IntImm(new_var.ty(), 0), best_range->extent); res_variables.push_back(new_var); res_ranges.Set(new_var, range); diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 20fd05169f43..e6465ad3cf93 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -615,7 +615,8 @@ CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs const PrimExpr& rhs_expr, bool propagate_inequalities) const { // Currently only supports integer checks - if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + if (lhs_expr.ty().code() != DLDataTypeCode::kDLInt || + rhs_expr.ty().code() != DLDataTypeCode::kDLInt) { return CompareResult::kUnknown; } diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index e9245c48a102..dfe7a3cf404b 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -58,14 +58,16 @@ class Scalarizer : public ExprMutator { } } PrimExpr VisitExpr_(const LetNode* op) final { - if (op->value.dtype().lanes() == 1) { + PrimType value_ty = op->value.ty(); + if (value_ty.lanes() == 1) { return ExprMutator::VisitExpr_(op); } auto it = let_var_remap_.find(op->var.get()); TVM_FFI_ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var; - Var new_var(op->var->name_hint + "_scalar", 
op->var.dtype().element_of()); + PrimType var_ty = op->var.ty(); + Var new_var(op->var->name_hint + "_scalar", var_ty.WithLanes(1)); let_var_remap_[op->var.get()] = new_var; PrimExpr value = this->VisitExpr(op->value); diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc index 604815c97955..9ceb156dead8 100644 --- a/src/arith/z3_prover.cc +++ b/src/arith/z3_prover.cc @@ -50,10 +50,10 @@ #include #include "tvm/ffi/cast.h" +#include "tvm/ffi/dtype.h" #include "tvm/ffi/object.h" #include "tvm/ffi/string.h" #include "tvm/ir/expr.h" -#include "tvm/runtime/data_type.h" #include "z3++.h" namespace tvm::arith { @@ -147,14 +147,14 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Create a Free z3 expression from PrimExprNode z3::expr Create(const PrimExprNode* op) { auto ref = ffi::GetRef(op); - auto dtype = op->dtype; + PrimType dtype = op->ty(); std::string name = ns.GetNewName(ref); /// TVM max_val can't handle uint64 max correctly, so we special case it here - if (dtype.is_bool()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLBool)) { return ctx->bool_const(name.c_str()); } else { z3::expr e = ctx->int_const(name.c_str()); - if (dtype.is_uint() && dtype.bits() == 64) { + if (dtype.MatchesCode(DLDataTypeCode::kDLUInt) && dtype.bits() == 64) { solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); } else { auto min_val = min_value(dtype).as_or_throw()->value; @@ -249,7 +249,7 @@ class Z3Prover::Impl : ExprFunctor { // solver) must degrade to "cannot prove" instead of escaping to the caller. 
try { if (CheckTrivilBadCases(expr)) return false; - if (!IsValidDType(expr->dtype)) return false; + if (!IsValidType(expr.ty())) return false; z3::expr_vector constr(*ctx); constr.push_back(!ConvertBool(expr)); auto result = solver.check(constr); @@ -263,7 +263,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Binded /// @brief Bind a variable to a value or a range void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) { - if (!IsValidDType(var->dtype)) return; + if (!IsValidType(var.ty())) return; scope_stack_.back().push_back(Scope{Scope::BindValue, var, value}); // we add the binding whenever the value is pure, // because non-pure parts are handling by creating free variables in VisitExpr @@ -272,7 +272,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Bind a variable to a range void Bind(const Var& var, const Range& range, bool allow_override = false) { - if (!IsValidDType(var->dtype)) return; + if (!IsValidType(var.ty())) return; scope_stack_.back().push_back( Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent}); // 1. 
Create a placeholder for the var, and save it in the memo @@ -427,7 +427,7 @@ class Z3Prover::Impl : ExprFunctor { * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met */ int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { - if (!IsValidDType(var->dtype)) { + if (!IsValidType(var.ty())) { return -1; } @@ -550,12 +550,14 @@ class Z3Prover::Impl : ExprFunctor { } return e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || - (e->IsInstance() && !IsValidDType(e.as_or_throw()->value->dtype)); + (e->IsInstance() && !IsValidType(e.as_or_throw()->value.ty())); } /// @brief Check if the dtype is valid for z3 integer operations - static bool IsValidDType(const DataType& dtype) { - return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + static bool IsValidType(const PrimType& dtype) { + return dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt, + DLDataTypeCode::kDLBool) && + dtype.lanes() == 1; } /// @brief Visit the expression and convert it into z3 integer expression @@ -581,7 +583,7 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Helper function to visit binary arithmetic operations z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, const PrimExpr& b) { - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { return signed_op(VisitInt(a), VisitInt(b)); } else { return Create(op); @@ -589,14 +591,14 @@ class Z3Prover::Impl : ExprFunctor { } z3::expr VisitExpr_(const LetNode* op) override { - if (IsValidDType(op->var->dtype)) { + if (IsValidType(op->var.ty())) { memo_.emplace(op->var, VisitInt(op->value)); } return VisitExpr(op->body); } z3::expr VisitExpr_(const CastNode* op) override { // if the inner dtype is valid, we just visit it - if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + if (IsValidType(op->value.ty()) && 
IsValidType(op->ty())) { return VisitInt(op->value); } else { // otherwise, we create a new free z3 variable @@ -696,7 +698,7 @@ class Z3Prover::Impl : ExprFunctor { } else if (op->op.same_as(tirx::builtin::shift_right())) { return VisitShiftOp(z3::ashr, op); } else if (op->op.same_as(tirx::builtin::if_then_else()) && op->args.size() == 3 && - IsValidDType(op->args[1]->dtype) && IsValidDType(op->args[2]->dtype)) { + IsValidType(op->args[1].ty()) && IsValidType(op->args[2].ty())) { // tir.if_then_else(cond, a, b) is a select-like ternary. return z3::ite(VisitBool(op->args[0]), VisitInt(op->args[1]), VisitInt(op->args[2])); } else { @@ -715,9 +717,9 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& a = op->args[0]; const PrimExpr& b = op->args[1]; - unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + unsigned bit_width = std::max(op->args[0].ty().bits(), op->args[1].ty().bits()); - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { return z3::bv2int( op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); } else { @@ -734,9 +736,9 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& a = op->args[0]; - if (IsValidDType(a->dtype)) { + if (IsValidType(a.ty())) { // Cast integer to bit-vector, apply bitwise not, then cast back. 
- unsigned bit_width = a.dtype().bits(); + unsigned bit_width = a.ty().bits(); z3::expr a_int = VisitInt(a); z3::expr a_bv = z3::int2bv(bit_width, a_int); return z3::bv2int(~a_bv, true); @@ -756,7 +758,7 @@ class Z3Prover::Impl : ExprFunctor { const PrimExpr& b = op->args[1]; // Shift operations require integer types for both operands - if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + if (IsValidType(a.ty()) && IsValidType(b.ty())) { z3::expr a_expr = VisitInt(a); z3::expr b_expr = VisitInt(b); @@ -765,7 +767,7 @@ class Z3Prover::Impl : ExprFunctor { // matching push/pop in this path, so the assertion would permanently // poison the shared solver and make all subsequent unrelated proofs about // `b` unsound. - unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + unsigned bit_width = std::max(a.ty().bits(), b.ty().bits()); z3::expr a_bv = z3::int2bv(bit_width, a_expr); z3::expr b_bv = z3::int2bv(bit_width, b_expr); diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index 0f2838014b28..0d70d9aef3fd 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -56,13 +56,32 @@ bool IsOp(const tirx::CallNode* call, const Op& compat_op, const char* canonical return op_node != nullptr && op_node->name == canonical_name; } +bool IsCUDAFloat8(DLDataTypeCode code) { + return code == DLDataTypeCode::kDLFloat8_e3m4 || code == DLDataTypeCode::kDLFloat8_e4m3 || + code == DLDataTypeCode::kDLFloat8_e4m3b11fnuz || + code == DLDataTypeCode::kDLFloat8_e4m3fn || code == DLDataTypeCode::kDLFloat8_e4m3fnuz || + code == DLDataTypeCode::kDLFloat8_e5m2 || code == DLDataTypeCode::kDLFloat8_e5m2fnuz || + code == DLDataTypeCode::kDLFloat8_e8m0fnu; +} + +bool IsCUDAFloat6(DLDataTypeCode code) { + return code == DLDataTypeCode::kDLFloat6_e2m3fn || code == DLDataTypeCode::kDLFloat6_e3m2fn; +} + +bool IsCUDAFloat4(DLDataTypeCode code) { return code == DLDataTypeCode::kDLFloat4_e2m1fn; 
} + +bool IsCUDAPackedFloat(DLDataTypeCode code) { + return IsCUDAFloat8(code) || IsCUDAFloat6(code) || IsCUDAFloat4(code); +} + } // namespace -std::string GetFP8Type(DataType type) { +std::string GetFP8Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -78,11 +97,12 @@ std::string GetFP8Type(DataType type) { } stream << "__nv_fp8"; std::string suffix; - if (type.code() == DataType::kFloat8_e4m3fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat8_e4m3fn) { suffix = "_e4m3"; - } else if (type.code() == DataType::kFloat8_e5m2) { + } else if (code == DLDataTypeCode::kDLFloat8_e5m2) { suffix = "_e5m2"; - } else if (type.code() == DataType::kFloat8_e8m0fnu) { + } else if (code == DLDataTypeCode::kDLFloat8_e8m0fnu) { suffix = "_e8m0"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP8 type in CUDA codegen"; @@ -91,11 +111,12 @@ std::string GetFP8Type(DataType type) { return stream.str(); } -std::string GetFP6Type(DataType type) { +std::string GetFP6Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -110,9 +131,10 @@ std::string GetFP6Type(DataType type) { } stream << "__nv_fp6"; std::string suffix; - if (type.code() == DataType::kFloat6_e2m3fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat6_e2m3fn) { suffix = "_e2m3"; - } else if (type.code() == DataType::kFloat6_e3m2fn) { + } else if (code == DLDataTypeCode::kDLFloat6_e3m2fn) { suffix = "_e3m2"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP6 type in CUDA codegen"; @@ -121,11 +143,12 @@ std::string GetFP6Type(DataType 
type) { return stream.str(); } -std::string GetFP4Type(DataType type) { +std::string GetFP4Type(DLDataType type) { + PrimType type_ty(type); std::stringstream stream; - int32_t lanes = type.lanes(); + int32_t lanes = type_ty.lanes(); std::string vec; - if (type.is_scalar()) { + if (type_ty.IsScalar()) { vec = ""; } else if (lanes == 2) { vec = "x2"; @@ -140,7 +163,8 @@ std::string GetFP4Type(DataType type) { } stream << "__nv_fp4"; std::string suffix; - if (type.code() == DataType::kFloat4_e2m1fn) { + DLDataTypeCode code = type_ty.code(); + if (code == DLDataTypeCode::kDLFloat4_e2m1fn) { suffix = "_e2m1"; } else { TVM_FFI_THROW(InternalError) << "Unsupported FP4 type in CUDA codegen"; @@ -299,31 +323,34 @@ void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { ";\" : \"=r\"(ctaid) :);\n" " return ctaid;\n" "}\n"); - var_idmap_[iv->var.get()] = CastFromTo(func_name + "()", DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = + CastFromTo(func_name + "()", DLDataType{kDLUInt, 32, 1}, iv->var.ty()->dtype); } else { - var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = + CastFromTo(iv->thread_tag, DLDataType{kDLUInt, 32, 1}, iv->var.ty()->dtype); } } -void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { - TVM_FFI_ICHECK(t.is_scalar()) << "do not yet support vector types"; + if (t.IsHandle()) { + TVM_FFI_ICHECK(t.IsScalar()) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: codegen_tags_.insert("fp16"); - if (t.is_scalar()) { + if (t.IsScalar()) { os << "half"; } else if (lanes <= 8) { TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "Only 
support an even number of lanes for half type"; @@ -360,15 +387,15 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; break; } - if (!fail && (t.is_scalar() || t.bits() == 16)) return; + if (!fail && (t.IsScalar() || t.bits() == 16)) return; if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } - } else if (t.is_bfloat16()) { + } else if (t.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { codegen_tags_.insert("bf16"); - if (t.is_scalar()) { + if (t.IsScalar()) { os << "nv_bfloat16"; } else if (lanes <= 8) { TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type"; @@ -381,57 +408,65 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } if (!fail) return; - } else if (t.is_float8()) { + } else if (t.code() == DLDataTypeCode::kDLFloat8_e3m4 || + t.code() == DLDataTypeCode::kDLFloat8_e4m3 || + t.code() == DLDataTypeCode::kDLFloat8_e4m3b11fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e4m3fn || + t.code() == DLDataTypeCode::kDLFloat8_e4m3fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e5m2 || + t.code() == DLDataTypeCode::kDLFloat8_e5m2fnuz || + t.code() == DLDataTypeCode::kDLFloat8_e8m0fnu) { codegen_tags_.insert("fp8"); - if (t.lanes() <= 4) { - os << GetFP8Type(t); + if (lanes <= 4) { + os << GetFP8Type(raw_t); } else { - os << "uint" << t.lanes() / 4; + os << "uint" << lanes / 4; } return; - } else if (t.is_float6()) { + } else if (t.code() == DLDataTypeCode::kDLFloat6_e2m3fn || + t.code() == DLDataTypeCode::kDLFloat6_e3m2fn) { codegen_tags_.insert("fp6"); - if (t.lanes() <= 4) { - os << GetFP6Type(t); + if (lanes <= 4) { + os << GetFP6Type(raw_t); } else { fail = true; } return; - } else if (t.is_float4()) { + } else if (t.code() == DLDataTypeCode::kDLFloat4_e2m1fn) { codegen_tags_.insert("fp4"); - if (t.lanes() <= 4) { - os << GetFP4Type(t); + if (lanes <= 4) { + os << GetFP4Type(raw_t); 
} else { fail = true; } return; - } else if (t == DataType::Bool()) { + } else if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; - } else if (t.is_vector_bool()) { + } else if (t.code() == DLDataTypeCode::kDLBool && lanes > 1) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. - int n = t.lanes(); + int n = lanes; if (n <= 4) { os << "ushort" << n; return; } - } else if (t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "u"; } switch (t.bits()) { case 1: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { os << "int8_t"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { os << "int16_t"; return; - } else if (t.lanes() == 32) { + } else if (lanes == 32) { os << "int"; return; } else { @@ -439,23 +474,23 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 4: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; return; - } else if (t.lanes() == 4) { + } else if (lanes == 4) { os << "int16_t"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { // directly 8 4-bit int in integer. os << "int"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { os << "int2"; return; - } else if (t.lanes() == 32) { + } else if (lanes == 32) { os << "int4"; return; - } else if (t.lanes() == 64) { + } else if (lanes == 64) { os << "int8"; return; } else { @@ -463,7 +498,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 8: { - if (t.lanes() == 4) { + if (lanes == 4) { // directly 4 8 bit int in integer. codegen_tags_.insert("int8"); @@ -472,15 +507,15 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // into 32-bit data. 
os << "int"; return; - } else if (t.lanes() == 8) { + } else if (lanes == 8) { codegen_tags_.insert("int8"); os << "int2"; return; - } else if (t.lanes() == 16) { + } else if (lanes == 16) { codegen_tags_.insert("int8"); os << "int4"; return; - } else if (!t.is_uint() && t.is_scalar()) { + } else if (!t.MatchesCode(DLDataTypeCode::kDLUInt) && t.IsScalar()) { os << "signed char"; break; } else { @@ -489,11 +524,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } } case 16: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "short"; - } else if (t.lanes() <= 4) { + } else if (lanes <= 4) { os << "short" << lanes; - } else if (t.lanes() <= 8) { + } else if (lanes <= 8) { // Emit CUDA code to access int16 vector elements. // // short4 is stored as int2 @@ -503,9 +538,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // s4.z is emitted as *(short2*)(&(i2.y)).x // s4.w is emitted as *(short2*)(&(i2.y)).y // - TVM_FFI_ICHECK_EQ(t.lanes() % 2, 0) - << "only support even lane for shorT type with lanes > 4"; - os << "int" << t.lanes() / 2; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for shorT type with lanes > 4"; + os << "int" << lanes / 2; } else { fail = true; } @@ -515,11 +549,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) break; } case 32: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int"; - } else if (t.lanes() <= 4) { - os << "int" << t.lanes(); - } else if (t.lanes() <= 8) { + } else if (lanes <= 4) { + os << "int" << lanes; + } else if (lanes <= 8) { // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. 
// // int8 is stored as longlong4 @@ -538,13 +572,13 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) break; } case 64: { - if (t.is_scalar()) { + if (t.IsScalar()) { os << "int64_t"; - } else if (t.lanes() == 2) { + } else if (lanes == 2) { os << "longlong2"; - } else if (t.lanes() == 3) { + } else if (lanes == 3) { os << "longlong3"; - } else if (t.lanes() == 4) { + } else if (lanes == 4) { os << "longlong4"; } return; @@ -561,15 +595,16 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to CUDA type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to CUDA type"; } -void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { +void CodeGenCUDA::PrintVecConstructor(DLDataType t, std::ostream& os) { os << "make_"; PrintType(t, os); } -void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, +void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DLDataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Declare the result. std::string sret = name_supply_->FreshName("_"); @@ -579,22 +614,22 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l int ssa_scope = BeginScope(); { // Unpack into individual ops. 
- std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); - std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.ty()->dtype); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.ty()->dtype); - for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + for (int i = 0, lanes = PrimType(t).lanes(); i < lanes; ++i) { std::ostringstream value_temp; if (isalpha(op[0])) { value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.ty()->dtype, i, value_temp); value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.ty()->dtype, i, value_temp); value_temp << ")"; } else { value_temp << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + PrintVecElemLoad(vlhs, lhs.ty()->dtype, i, value_temp); value_temp << op; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + PrintVecElemLoad(vrhs, rhs.ty()->dtype, i, value_temp); value_temp << ")"; } PrintVecElemStore(sret, t, i, value_temp.str()); @@ -604,55 +639,58 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l os << sret; } -void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) - if (t.is_scalar()) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + if (t_ty.IsScalar()) { os << vec; return; } static const char access[] = {'x', 'y', 'z', 'w'}; - TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - std::string type_name = t.is_int() ? "signed char" : "unsigned char"; - if (t.lanes() == 2 || t.lanes() == 3) { - os << vec << "." << access[i % t.lanes()]; + TVM_FFI_ICHECK(i >= 0 && i < (t.bits == 8 ? 16 : (t.bits == 16 || t.bits == 32) ? 
8 : 4)); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + std::string type_name = + t_ty.MatchesCode(DLDataTypeCode::kDLInt) ? "signed char" : "unsigned char"; + if (lanes == 2 || lanes == 3) { + os << vec << "." << access[i % lanes]; } else { - std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); + std::string ac = lanes == 4 ? vec : (vec + "." + access[i / 4]); os << "(reinterpret_cast(&(" << ac << "))[" << (i % 4) << "])"; } - } else if (t.is_float16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { + if (lanes <= 4) { os << vec << "." << access[i]; } else { os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } - } else if (t.is_bfloat16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { + if (lanes <= 4) { os << vec << "." << access[i]; } else { os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } - } else if (t.lanes() > 4 && t.lanes() <= 8) { + } else if (lanes > 4 && lanes <= 8) { std::string type_name; - if (t.bits() == 16) { - if (t.is_int()) { + if (t.bits == 16) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "short"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "ushort"; } - } else if (t.bits() == 32) { - if (t.is_int()) { + } else if (t.bits == 32) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "int"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "uint"; - } else if (t.is_float()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat) { type_name = "float"; } } TVM_FFI_ICHECK(!type_name.empty()); os << "((" << type_name << "2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; - } else if (t.is_float4_e2m1fn()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat4_e2m1fn) { os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec << ".__x >> " << i * 4 << ") & 0xF)"; } else { @@ -660,50 +698,53 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } -void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (t.lanes() == 2 || t.lanes() == 3) { - stream << vec << '.' << access[i % t.lanes()] << "=" + TVM_FFI_ICHECK(i >= 0 && i < (t.bits == 8 ? 16 : (t.bits == 16 || t.bits == 32) ? 8 : 4)); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + if (lanes == 2 || lanes == 3) { + stream << vec << '.' << access[i % lanes] << "=" << "(" << value << ");\n"; } else { - std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); - std::string type_name = t.is_int() ? "signed char" : "unsigned char"; + std::string ac = lanes == 4 ? vec : (vec + "." + access[i / 4]); + std::string type_name = + t_ty.MatchesCode(DLDataTypeCode::kDLInt) ? "signed char" : "unsigned char"; stream << "reinterpret_cast<" << type_name << "*>(&(" << ac << "))[" << (i % 4) << "] = (" << type_name << ")(" << value << ");\n"; } - } else if (t.is_float16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { + if (lanes <= 4) { stream << vec << "." << access[i] << " = " << value << ";\n"; } else { stream << "((half2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } - } else if (t.is_bfloat16()) { - if (t.lanes() <= 4) { + } else if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { + if (lanes <= 4) { stream << vec << "." << access[i] << " = " << value << ";\n"; } else { stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } - } else if (t.lanes() > 4 && t.lanes() <= 8) { + } else if (lanes > 4 && lanes <= 8) { std::string type_name; - if (t.bits() == 16) { - if (t.is_int()) { + if (t.bits == 16) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "short"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "ushort"; } - } else if (t.bits() == 32) { - if (t.is_int()) { + } else if (t.bits == 32) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { type_name = "int"; - } else if (t.is_uint()) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { type_name = "uint"; - } else if (t.is_float()) { + } else if (t_ty.code() == DLDataTypeCode::kDLFloat) { type_name = "float"; } } @@ -766,15 +807,19 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) } } -std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) { +std::string CodeGenCUDA::CastFromTo(std::string value, DLDataType from, DLDataType target) { if (from == target) return value; + PrimType from_ty(from); + PrimType target_ty(target); std::ostringstream os; os << "(("; this->PrintType(target, os); os << ")"; - if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { + if (from_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (target_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && + target.bits == 8) { os << "("; - if (target.is_uint()) { + if (target_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "u"; } os << "int)"; @@ -794,33 +839,22 @@ void 
CodeGenCUDA::AddUtilFunction(const std::string& func_name, const std::strin } void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { - DataType from_ty = op->value.dtype(); - DataType target_ty = op->dtype; + DLDataType from_dtype = op->value.ty()->dtype; + DLDataType target_dtype = op->ty()->dtype; + PrimType from_ty(from_dtype); + PrimType target_ty(target_dtype); TVM_FFI_ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. - if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); - - if (target_ty.code() == DataType::kFloat8_e3m4 || target_ty.code() == DataType::kFloat8_e4m3 || - target_ty.code() == DataType::kFloat8_e4m3b11fnuz || - target_ty.code() == DataType::kFloat8_e4m3fn || - target_ty.code() == DataType::kFloat8_e4m3fnuz || - target_ty.code() == DataType::kFloat8_e5m2 || - target_ty.code() == DataType::kFloat8_e5m2fnuz || - target_ty.code() == DataType::kFloat8_e8m0fnu || - target_ty.code() == DataType::kFloat4_e2m1fn || - - from_ty.code() == DataType::kFloat8_e3m4 || from_ty.code() == DataType::kFloat8_e4m3 || - from_ty.code() == DataType::kFloat8_e4m3b11fnuz || - from_ty.code() == DataType::kFloat8_e4m3fn || from_ty.code() == DataType::kFloat8_e4m3fnuz || - from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat8_e5m2fnuz || - from_ty.code() == DataType::kFloat8_e8m0fnu || from_ty.code() == DataType::kFloat4_e2m1fn) { + if (from_ty.IsScalar()) return CodeGenC::VisitExpr_(op, os); + + if (IsCUDAPackedFloat(target_ty.code()) || IsCUDAPackedFloat(from_ty.code())) { std::ostringstream val; - if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) { + if (target_ty.code() == DLDataTypeCode::kDLBfloat && target_ty.lanes() == 2) { val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")"; } else { val << "("; - PrintType(target_ty, val); + PrintType(target_dtype, val); val << ")(" << PrintExpr(op->value) << ")"; } os << val.str(); @@ -831,18 +865,18 @@ void 
CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { // too compact to read. Emit this as vectorized unary ops. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(target_ty, stream); + this->PrintType(target_dtype, stream); stream << ' ' << sret << ";\n"; { - std::string src = SSAGetID(PrintExpr(op->value), from_ty); + std::string src = SSAGetID(PrintExpr(op->value), from_dtype); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; val << "("; - PrintType(target_ty.element_of(), val); + PrintType(DLDataType{target_dtype.code, target_dtype.bits, 1}, val); val << ")("; - PrintVecElemLoad(src, from_ty, i, val); + PrintVecElemLoad(src, from_dtype, i, val); val << ")"; - PrintVecElemStore(sret, target_ty, i, val.str()); + PrintVecElemStore(sret, target_dtype, i, val.str()); } } os << sret; @@ -851,8 +885,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg, std::ostream& os) { // NOLINT(*) - DataType ret_dtype = GetRuntimeDataType(ret_type); - if (ret_dtype.is_fixed_length_vector()) { + DLDataType ret_dtype = GetRuntimeDataType(ret_type); + PrimType ret_ty(ret_dtype); + if (ret_ty.IsFixedLengthVector()) { // // Emit an unsupported vector call // @@ -881,17 +916,17 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, ffi::String global_symbol, std::vector sargs; size_t arg_begin = static_cast(skip_first_arg); for (size_t i = arg_begin; i < args.size(); ++i) { - std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + std::string val = SSAGetID(PrintExpr(args[i]), args[i].ty()->dtype); sargs.push_back(std::move(val)); } // Emit a scalar call for each lane. 
- for (int i = 0; i < ret_dtype.lanes(); ++i) { + for (int i = 0; i < ret_ty.lanes(); ++i) { std::ostringstream scall; scall << global_symbol << "("; for (size_t j = 0; j < sargs.size(); ++j) { if (j > 0) scall << ", "; - PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + PrintVecElemLoad(sargs[j], args[arg_begin + j].ty()->dtype, i, scall); } scall << ")"; PrintVecElemStore(sret, ret_dtype, i, scall.str()); @@ -1196,7 +1231,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); - if (trans && op->dtype.bits() == 8) { + if (trans && op->ty()->dtype.bits == 8) { // ldmatrix can't transpose 8-bit elements (it assumes 16-bit), so // synthesize the equivalent manual gather loop. args[6] is the // shared-memory stride for this fallback. @@ -1317,39 +1352,46 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { << guard << ")\n"; stream << ");\n"; } else if (op->op.same_as(builtin::reinterpret())) { - DataType tgt_dtype = op->dtype; - DataType src_dtype = op->args[0]->dtype; + DLDataType tgt_dtype = op->ty()->dtype; + DLDataType src_dtype = op->args[0].ty()->dtype; + PrimType tgt_ty(tgt_dtype); + PrimType src_ty(src_dtype); PrimExpr value = op->args[0]; - if (src_dtype.is_handle() && tgt_dtype.is_scalar() && - (tgt_dtype.is_uint() || tgt_dtype.is_int()) && tgt_dtype.bits() == 64) { + if (src_ty.IsHandle() && tgt_ty.IsScalar() && + tgt_ty.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt) && + tgt_dtype.bits == 64) { os << "reinterpret_cast<"; this->PrintType(tgt_dtype, os); os << ">(" << PrintExpr(value) << ")"; return; } - if (tgt_dtype.is_handle() && src_dtype.is_scalar() && - (src_dtype.is_uint() || src_dtype.is_int()) && src_dtype.bits() == 64) { + if (tgt_ty.IsHandle() && src_ty.IsScalar() && + src_ty.MatchesCode(DLDataTypeCode::kDLUInt, 
DLDataTypeCode::kDLInt) && + src_dtype.bits == 64) { os << "reinterpret_cast(" << PrintExpr(value) << ")"; return; } // Handle float4_e2m1fn reinterpret - if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) { + if (!IsCUDAFloat4(src_ty.code()) && !IsCUDAFloat4(tgt_ty.code())) { return CodeGenC::VisitExpr_(op, os); } if (src_dtype == tgt_dtype || - tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) { + tgt_ty.lanes() * tgt_dtype.bits == src_ty.lanes() * src_dtype.bits) { return CodeGenC::VisitExpr_(op, os); } - TVM_FFI_ICHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) + TVM_FFI_ICHECK_EQ(tgt_ty.lanes(), src_ty.lanes()) << "E2M1 float4 reinterpret expects source and target to have the same number of lanes. " - << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; - TVM_FFI_ICHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) + << "Source dtype: " << ffi::DLDataTypeToString(src_dtype) + << ", Target dtype: " << ffi::DLDataTypeToString(tgt_dtype); + TVM_FFI_ICHECK_EQ((tgt_ty.lanes() * tgt_dtype.bits + 7) / 8, + (src_ty.lanes() * src_dtype.bits + 7) / 8) << "E2M1 float4 reinterpret expects source and target to have the same number of bytes. " - << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; + << "Source dtype: " << ffi::DLDataTypeToString(src_dtype) + << ", Target dtype: " << ffi::DLDataTypeToString(tgt_dtype); - int lanes = tgt_dtype.lanes(); + int lanes = tgt_ty.lanes(); int ssa_scope = BeginScope(); if (lanes == 1) { @@ -1360,47 +1402,47 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintType(tgt_dtype, os); os << " *)(&(" << rhs << ")))"; } else if (lanes == 2) { - if (tgt_dtype.is_float4_e2m1fn()) { + if (IsCUDAFloat4(tgt_ty.code())) { // We view the source as an uint16, and then extract bits of two fp4 numbers, // and finally reinterpret the result as fp4x2. 
- value = tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value}); - tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, - tirx::Cast(DataType::UInt(8), - (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); + tirx::Cast(PrimType::UInt(8), + (temp_var & IntImm(PrimType::UInt(16), 0xF)) | + ((temp_var >> 4) & IntImm(PrimType::UInt(16), 0xF0)))); } else { - value = tirx::Cast(DataType::UInt(16), - tirx::Call(DataType::UInt(8), tirx::builtin::reinterpret(), {value})); - tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Cast(PrimType::UInt(16), + tirx::Call(PrimType::UInt(8), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", PrimType::UInt(16)); value = tirx::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); + (temp_var & IntImm(PrimType::UInt(16), 0xF)) | + ((temp_var & IntImm(PrimType::UInt(16), 0xF0)) << 4)); } - os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(PrimType(tgt_dtype), tirx::builtin::reinterpret(), {value})); } else if (lanes == 4) { - if (tgt_dtype.is_float4_e2m1fn()) { + if (IsCUDAFloat4(tgt_ty.code())) { // We view the source as an uint32, and then extract bits of four fp4 numbers, // and finally reinterpret the result as fp4x4. 
- value = tirx::Call(DataType::UInt(32), tirx::builtin::reinterpret(), {value}); - tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Call(PrimType::UInt(32), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, - tirx::Cast(DataType::UInt(16), - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | - ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | - ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); + tirx::Cast(PrimType::UInt(16), + (temp_var & IntImm(PrimType::UInt(32), 0xF)) | + ((temp_var >> 4) & IntImm(PrimType::UInt(32), 0xF0)) | + ((temp_var >> 8) & IntImm(PrimType::UInt(32), 0xF00)) | + ((temp_var >> 12) & IntImm(PrimType::UInt(32), 0xF000)))); } else { - value = tirx::Cast(DataType::UInt(32), - tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value})); - tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Cast(PrimType::UInt(32), + tirx::Call(PrimType::UInt(16), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", PrimType::UInt(32)); value = tirx::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | - ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | - ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); + (temp_var & IntImm(PrimType::UInt(32), 0xF)) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF0)) << 4) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF00)) << 8) | + ((temp_var & IntImm(PrimType::UInt(32), 0xF000)) << 12)); } - os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(PrimType(tgt_dtype), tirx::builtin::reinterpret(), {value})); } else { TVM_FFI_THROW(InternalError) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; @@ -1411,7 +1453,8 @@ void CodeGenCUDA::VisitExpr_(const 
CallNode* op, std::ostream& os) { const PrimExpr& arg = op->args[0]; const auto* var_node = arg.as(); - DataType dtype = op->dtype; + DLDataType dtype = op->ty()->dtype; + PrimType dtype_ty(dtype); bool is_string = op->args[2].as()->value; bool is_scalar = op->args[3].as()->value; int num_dims = op->args[4].as()->value; @@ -1432,22 +1475,23 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (is_scalar) { // Scalar printing logic std::string format_specifier; - bool is_float16 = dtype.is_float() && dtype.bits() == 16; - if (dtype.is_float()) + bool is_float16 = dtype_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16); + if (dtype_ty.code() == DLDataTypeCode::kDLFloat) format_specifier = "%f"; - else if (dtype.is_int()) + else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLInt)) format_specifier = "%d"; - else if (dtype.is_uint()) + else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLUInt)) format_specifier = "%u"; else - TVM_FFI_THROW(InternalError) << "Unsupported data type for scalar print: " << dtype; + TVM_FFI_THROW(InternalError) + << "Unsupported data type for scalar print: " << ffi::DLDataTypeToString(dtype); std::string print_arg = var_node ? ("*" + GetVarID(var_node)) : PrintExpr(arg); os << "// print_buffer starts (scalar)\n" << "if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {\n" - << " printf(\"Scalar (dtype: " << dtype << "): " << format_specifier << "\\n\\n\", " - << (is_float16 ? "static_cast(" : "") << print_arg << (is_float16 ? ")" : "") - << ");\n" + << " printf(\"Scalar (dtype: " << ffi::DLDataTypeToString(dtype) + << "): " << format_specifier << "\\n\\n\", " << (is_float16 ? "static_cast(" : "") + << print_arg << (is_float16 ? 
")" : "") << ");\n" << "}\n" << "// print_buffer ends\n"; return; @@ -1460,19 +1504,20 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string format_specifier; bool is_float16 = false; - if (dtype.is_float()) { - if (dtype.bits() == 16) { + if (dtype_ty.code() == DLDataTypeCode::kDLFloat) { + if (dtype.bits == 16) { format_specifier = "%f"; is_float16 = true; } else { format_specifier = "%f"; } - } else if (dtype.is_int()) { + } else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLInt)) { format_specifier = "%d"; - } else if (dtype.is_uint()) { + } else if (dtype_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { format_specifier = "%u"; } else { - TVM_FFI_THROW(InternalError) << "Unsupported data type for print: " << dtype; + TVM_FFI_THROW(InternalError) + << "Unsupported data type for print: " << ffi::DLDataTypeToString(dtype); } TVM_FFI_ICHECK(var_node) << "Formatted print is only supported for buffer variables."; @@ -1485,7 +1530,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { for (int i = 0; i < num_dims; ++i) { os << PrintExpr(shape[i]) << (i < num_dims - 1 ? 
"," : ""); } - os << "), dtype=" << dtype << "):\\n\");\n"; + os << "), dtype=" << ffi::DLDataTypeToString(dtype) << "):\\n\");\n"; std::vector loop_vars; for (int i = 0; i < num_dims; ++i) { @@ -1572,7 +1617,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx.cp_async_commit_group"); - auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, {}); + auto commit_group = Call(PrimType::Void(), ptx_cp_async_commit_group_op, {}); this->PrintIndent(); this->VisitExpr(commit_group, this->stream); this->stream << ";\n"; @@ -1584,7 +1629,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx.cp_async_wait_group"); - auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); + auto wait_group = Call(PrimType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); this->PrintIndent(); this->VisitExpr(wait_group, this->stream); this->stream << ";\n"; @@ -1614,19 +1659,23 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) { this->PrintIndent(); std::string scope = GetPtrStorageScope(op->buffer->data); const VarNode* buffer = op->buffer->data.as(); - DataType dtype = op->buffer->dtype; + DLDataType dtype = op->buffer->dtype->dtype; if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Int(8) || - dtype == DataType::UInt(8) || dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || dtype == DataType::Int(1) || - dtype == DataType::BFloat(16)) + bool supported_wmma_input_dtype = + dtype == DLDataType{kDLFloat, 16, 1} || dtype == DLDataType{kDLInt, 8, 1} || + dtype == DLDataType{kDLUInt, 8, 1} || dtype == 
DLDataType{kDLInt, 4, 1} || + dtype == DLDataType{kDLUInt, 4, 1} || dtype == DLDataType{kDLInt, 1, 1} || + dtype == DLDataType{kDLBfloat, 16, 1}; + TVM_FFI_ICHECK(supported_wmma_input_dtype) << "Matrix_a and matrix_b only support half or char or unsigned char " << "or uint4 or int4 or int1 type for now"; } else { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Float(32) || - dtype == DataType::Int(32)) + bool supported_wmma_accumulator_dtype = dtype == DLDataType{kDLFloat, 16, 1} || + dtype == DLDataType{kDLFloat, 32, 1} || + dtype == DLDataType{kDLInt, 32, 1}; + TVM_FFI_ICHECK(supported_wmma_accumulator_dtype) << "Accumulator only support half, float and int type for now"; } PrintWmmaScope(scope, dtype, buffer, stream); @@ -1662,9 +1711,11 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) { if (scope.find("wmma.") == 0) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); } - if ((dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) && - scope == "shared") { - constant_size = constant_size / (32 / dtype.bits()); + bool is_packed_integer_dtype = dtype == DLDataType{kDLInt, 4, 1} || + dtype == DLDataType{kDLUInt, 4, 1} || + dtype == DLDataType{kDLInt, 1, 1}; + if (is_packed_integer_dtype && scope == "shared") { + constant_size = constant_size / (32 / dtype.bits); } stream << ' ' << vid << '[' << constant_size << "];\n"; } @@ -1693,9 +1744,10 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { } void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { - int lanes = op->dtype.lanes(); + PrimType op_ty = op->ty(); + int lanes = op_ty.lanes(); if (lanes <= 4) { - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" @@ -1710,16 +1762,16 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { // constructor argument layout does not match TIR vector 
lane layout. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(op->dtype, stream); + this->PrintType(op->ty()->dtype, stream); stream << ' ' << sret << ";\n"; int ssa_scope = BeginScope(); { - std::string vbase = SSAGetID(PrintExpr(op->base), op->base.dtype()); - std::string vstride = SSAGetID(PrintExpr(op->stride), op->stride.dtype()); + std::string vbase = SSAGetID(PrintExpr(op->base), op->base.ty()->dtype); + std::string vstride = SSAGetID(PrintExpr(op->stride), op->stride.ty()->dtype); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; value_temp << "(" << vbase << ")+(" << vstride << "*" << i << ")"; - PrintVecElemStore(sret, op->dtype, i, value_temp.str()); + PrintVecElemStore(sret, op->ty()->dtype, i, value_temp.str()); } } EndScope(ssa_scope); @@ -1727,14 +1779,16 @@ void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { } void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - int lanes = op->dtype.lanes(); - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { + PrimType op_ty = op->ty(); + int lanes = op_ty.lanes(); + if ((op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && op_ty.bits() == 8 && + lanes == 4) { // make_int8x4 const int64_t* p = as_const_int(op->value); TVM_FFI_ICHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; @@ -1742,9 +1796,9 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_float16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; if (lanes <= 4) { for (int i = 0; i < lanes / 2; ++i) { @@ -1761,9 +1815,9 @@ void 
CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_bfloat16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; if (lanes > 4) { for (int i = 0; i < lanes / 2; ++i) { @@ -1780,12 +1834,11 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if (op->dtype.is_float8() || op->dtype.is_float4()) { - int lanes = op->dtype.lanes(); + if (IsCUDAFloat8(op_ty.code()) || IsCUDAFloat4(op_ty.code())) { TVM_FFI_ICHECK(lanes == 1 || lanes == 2 || lanes == 4); std::string v = PrintExpr(op->value); // Implicit conversion from float back to fp8 - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << "(make_float" << lanes << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -1795,7 +1848,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO return; } - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + if ((op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && op_ty.bits() == 4) { bool fail = false; const int64_t* p = as_const_int(op->value); TVM_FFI_ICHECK(p); @@ -1803,7 +1856,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (lanes == 4) { v = (v << 12) | (v << 8) | (v << 4) | v; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint16_t)" << v; } else { os << "(int16_t)" << v; @@ -1811,17 +1864,17 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } else { v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; if (lanes == 8) { - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; } } else if (lanes == 16 || lanes == 32) { 
- PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; for (int i = 0; i < lanes / 8; ++i) { if (i != 0) os << ", "; - if (op->dtype.is_uint()) { + if (op_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { os << "(uint)" << v; } else { os << "(int)" << v; @@ -1839,7 +1892,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } std::string v = PrintExpr(op->value); - PrintVecConstructor(op->dtype, os); + PrintVecConstructor(op->ty()->dtype, os); os << '('; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -1849,47 +1902,49 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO } void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { + PrimType op_ty = op->ty(); // Non-vector cases. - if (!op->dtype.is_fixed_length_vector()) { + if (!op_ty.IsFixedLengthVector()) { CodeGenC::VisitExpr_(op, os); return; } // Codegen vector condition case by serializing the select op. - TVM_FFI_ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && - op->dtype.lanes() == op->condition.dtype().lanes()); + TVM_FFI_ICHECK(op->false_value.ty() == op_ty && op->true_value.ty() == op_ty && + op_ty.lanes() == op->condition.ty().lanes()); std::string r_var = name_supply_->FreshName("_"); this->PrintIndent(); - this->PrintType(op->dtype, stream); + this->PrintType(op->ty()->dtype, stream); stream << ' ' << r_var << ";\n"; { - std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); - std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); - std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); + std::string c_var = SSAGetID(PrintExpr(op->condition), op->ty()->dtype); + std::string t_var = SSAGetID(PrintExpr(op->true_value), op->ty()->dtype); + std::string f_var = SSAGetID(PrintExpr(op->false_value), op->ty()->dtype); // The condition is stored as an ushort vector. 
- int lanes = op->dtype.lanes(); - DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes); + int lanes = op_ty.lanes(); + DLDataType memory_dtype{kDLUInt, 16, static_cast(lanes)}; for (int i = 0; i < lanes; ++i) { std::ostringstream item; item << "(bool("; - PrintVecElemLoad(c_var, memory_ty, i, item); + PrintVecElemLoad(c_var, memory_dtype, i, item); item << ")?"; - PrintVecElemLoad(t_var, op->dtype, i, item); + PrintVecElemLoad(t_var, op->ty()->dtype, i, item); item << ':'; - PrintVecElemLoad(f_var, op->dtype, i, item); + PrintVecElemLoad(f_var, op->ty()->dtype, i, item); item << ')'; - PrintVecElemStore(r_var, op->dtype, i, item.str()); + PrintVecElemStore(r_var, op->ty()->dtype, i, item.str()); } } os << r_var; } inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) + PrimType op_ty = op->ty(); // Type code is kBFloat - if (op->dtype.is_bfloat16()) { + if (op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { os << "__float2bfloat16_rn"; os << '(' << std::hexfloat << op->value << 'f'; os << "/*" << std::scientific << op->value << "*/"; @@ -1897,15 +1952,15 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) return; } // Type code is kFloat8_e5m2 or kE4M4Float - if (op->dtype.is_float8() || op->dtype.is_float4()) { - p->PrintType(op->dtype, os); + if (IsCUDAFloat8(op_ty.code()) || IsCUDAFloat4(op_ty.code())) { + p->PrintType(op->ty()->dtype, os); os << '(' << std::hexfloat << op->value << 'f'; os << "/*" << std::scientific << op->value << "*/"; os << ')'; return; } // Type code is kFloat - switch (op->dtype.bits()) { + switch (op_ty.bits()) { case 64: { std::ostringstream temp; if (std::isinf(op->value)) { @@ -1945,13 +2000,14 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } case 16: { os << "__float2half_rn" << '('; - FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); + FloatImm const_f32 = FloatImm(PrimType::Float(32), op->value); 
PrintConst(const_f32.get(), os, p); os << ')'; break; } default: - TVM_FFI_THROW(InternalError) << "Bad bit-width for float: " << op->dtype << "\n"; + TVM_FFI_THROW(InternalError) + << "Bad bit-width for float: " << ffi::DLDataTypeToString(op->ty()->dtype) << "\n"; } } @@ -1959,25 +2015,27 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, +void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DLDataType t, const VarNode* variable, std::ostream& os) { + PrimType t_ty(t); std::stringstream type; PrintType(t, type); TVM_FFI_ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " << variable->name_hint; std::string shape_str = fragment_shapes.at(variable); - if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { + if ((t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) && t.bits < 8 && + t_ty.lanes() == 1) { type.str(std::string()); - if (t.is_int()) { - if (t.bits() == 4) { + if (t_ty.MatchesCode(DLDataTypeCode::kDLInt)) { + if (t.bits == 4) { type << "nvcuda::wmma::experimental::precision::s4"; - } else if (t.bits() == 1) { + } else if (t.bits == 1) { type << "nvcuda::wmma::experimental::precision::b1"; } else { TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; } - } else if (t.is_uint()) { - if (t.bits() == 4) { + } else if (t_ty.MatchesCode(DLDataTypeCode::kDLUInt)) { + if (t.bits == 4) { type << "nvcuda::wmma::experimental::precision::u4"; } else { TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; @@ -2029,20 +2087,25 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// - if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { + PrimType op_ty = op->ty(); + if ((op_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) || + op_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) && + IsVolatile(op->buffer->data.get())) { os << "("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")(" << value << ")"; } else { os << value; } } -void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, +void CodeGenCUDA::PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) { - TVM_FFI_ICHECK_GT(t.lanes(), 1); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (!(t.lanes() == 2 || t.lanes() == 3)) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + TVM_FFI_ICHECK_GT(lanes, 1); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { + if (!(lanes == 2 || lanes == 3)) { if (i != 0) { os << "|"; } @@ -2051,12 +2114,12 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val } } - if (t.is_float16()) { + if (t_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { if (i == 0) { PrintVecConstructor(t, os); os << '('; } - if (i == t.lanes() - 1) { + if (i == lanes - 1) { os << value << ")"; } else { os << value << ","; @@ -2064,12 +2127,12 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val return; } - if (t.is_bfloat16()) { + if (t_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { if (i == 0) { PrintVecConstructor(t, os); os << '('; } - if (i == t.lanes() - 1) { + if (i == lanes - 1) { os << value << ")"; } else { os << value << ","; @@ -2082,7 +2145,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << "("; } os << value; - if (i != t.lanes() - 1) { + if (i != lanes - 1) { os << ","; } else { os << ")"; diff --git a/src/backend/cuda/codegen/codegen_cuda.h 
b/src/backend/cuda/codegen/codegen_cuda.h index 92ca3cab34a4..94f86614e45e 100644 --- a/src/backend/cuda/codegen/codegen_cuda.h +++ b/src/backend/cuda/codegen/codegen_cuda.h @@ -56,16 +56,17 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const WhileNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintVecConstructor(DataType t, std::ostream& os) final; - void PrintVecElemLoad(const std::string& vec, DataType t, int i, + void PrintVecBinaryOp(const std::string& op, DLDataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) + void PrintVecConstructor(DLDataType t, std::ostream& os) final; + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; // NOLINT(*) - void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) - void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; - std::string CastFromTo(std::string value, DataType from, DataType target) final; + void PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) final; + std::string CastFromTo(std::string value, DLDataType from, DLDataType target) final; void AddUtilFunction(const std::string& name, const std::string& code); // overload visitor void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) @@ -129,7 +130,7 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map 
fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + void PrintWmmaScope(const std::string& scope, DLDataType t, const VarNode* variable, std::ostream& os); int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; diff --git a/src/backend/cuda/codegen/intrin_rule_cuda.cc b/src/backend/cuda/codegen/intrin_rule_cuda.cc index dc8d4a020e1e..ea2d0abfa80e 100644 --- a/src/backend/cuda/codegen/intrin_rule_cuda.cc +++ b/src/backend/cuda/codegen/intrin_rule_cuda.cc @@ -34,8 +34,8 @@ namespace intrin { using tirx::FLowerIntrinsic; struct CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float()) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 64: // Use nearbyint (ties-to-even) for round to match constant-folding semantics. 
@@ -56,7 +56,7 @@ struct CUDAMath { default: return ""; } - } else if (t.is_bfloat16()) { + } else if (t.code() == DLDataTypeCode::kDLBfloat && t.bits() == 16) { if (name == "fabs") { return "__habs"; } else if (name == "round") { @@ -64,7 +64,7 @@ struct CUDAMath { } else { return "h" + name; } - } else if (t.is_int() || t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { switch (t.bits()) { case 32: return "__" + name; @@ -79,8 +79,8 @@ struct CUDAMath { }; struct CUDAFastMath : public CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float() && t.bits() == 32) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat && t.bits() == 32) { return "__" + name + 'f'; } else { return CUDAMath::operator()(t, name); @@ -90,8 +90,8 @@ struct CUDAFastMath : public CUDAMath { }; struct CUDAFastMathTan : public CUDAMath { - std::string operator()(DataType t, std::string name) const { - if (t.is_float()) { + std::string operator()(PrimType t, std::string name) const { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 64: return name; @@ -110,8 +110,8 @@ struct CUDAFastMathTan : public CUDAMath { }; struct CUDAPopcount { - std::string operator()(DataType t, std::string name) const { - if (t.is_uint()) { + std::string operator()(PrimType t, std::string name) const { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { switch (t.bits()) { case 32: return "__popc"; @@ -126,7 +126,7 @@ struct CUDAPopcount { }; struct CUDAWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& cuda_shfl_sync_op = Op::Get("tirx.cuda.__shfl_sync"); return cuda_shfl_sync_op; @@ -147,7 +147,7 @@ struct CUDAWarpIntrinsic { static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = 
e.as(); static const Op& cuda_active_mask_op = Op::Get("tirx.cuda.__activemask"); - return Call(call->dtype, cuda_active_mask_op, call->args); + return Call(e.ty(), cuda_active_mask_op, call->args); } template @@ -156,7 +156,7 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), cuda_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), cuda_args); } void RegisterCudaIntrinRules() { diff --git a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc index e523e2b22aab..eb84f10fda10 100644 --- a/src/backend/cuda/codegen/llvm/codegen_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc @@ -87,7 +87,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - DataType dtype = op->buffer->dtype; + PrimType dtype = op->buffer->dtype; if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { // Shared memory: address space == 3 @@ -230,7 +230,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // corresponding nvvm intrinsic. Return true if the match is successful. static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) { // Only 32 bit data type is supported. 
- if (op->dtype.is_fixed_length_vector() || op->dtype.bits() != 32) { + PrimType op_ty = op->ty(); + if (op_ty.IsFixedLengthVector() || op_ty.bits() != 32) { return false; } @@ -253,7 +254,7 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) return false; } - *id = ids[offset + op->dtype.is_float()]; + *id = ids[offset + (op_ty.code() == DLDataTypeCode::kDLFloat)]; return true; } @@ -279,10 +280,11 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true); return builder_->CreateCall(val); } else if (op->op.same_as(builtin::atomic_add())) { - TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + PrimType value_ty = op->args[1].ty(); + TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); - if (op->args[1]->dtype.is_float()) { + if (value_ty.code() == DLDataTypeCode::kDLFloat) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); } diff --git a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc index d8706a94b181..13d6f7d95a3b 100644 --- a/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/intrin_rule_nvptx.cc @@ -38,7 +38,8 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) + PrimType call_ty = call->ty(); + TVM_FFI_ICHECK(call_ty.bits() == 32 || call_ty.bits() == 64) << "Only support float32 or float64."; const OpNode* op = call->op.as(); @@ -48,13 +49,13 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" 
<< name.substr(5); - if (call->dtype.bits() == 32) intrinsic_name << "f"; + if (call_ty.bits() == 32) intrinsic_name << "f"; ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->ty(), builtin::call_pure_extern(), new_args); } namespace llvm { @@ -73,7 +74,7 @@ TVM_REGISTER_OP("tirx.round") const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); static const Op& nearbyint_op = Op::Get("tirx.nearbyint"); - auto new_call = Call(call->dtype, nearbyint_op, call->args); + auto new_call = Call(call->ty(), nearbyint_op, call->args); return DispatchPureExternLibDevice(new_call); }); diff --git a/src/backend/cuda/runtime/cuda_device_api.cc b/src/backend/cuda/runtime/cuda_device_api.cc index 68ae39de56bf..6e30df29aa91 100644 --- a/src/backend/cuda/runtime/cuda_device_api.cc +++ b/src/backend/cuda/runtime/cuda_device_api.cc @@ -426,7 +426,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_ICHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; size_t arg_cnt = 0; CUtensorMap* tensor_map = static_cast(args[arg_cnt++].cast()); - runtime::DataType tensor_dtype = args[arg_cnt++].cast(); + DLDataType tensor_dtype = args[arg_cnt++].cast(); int32_t raw_tensor_rank = args[arg_cnt++].cast(); TVM_FFI_ICHECK_GT(raw_tensor_rank, 0) << "tensorRank must be non-zero"; TVM_FFI_ICHECK_LE(raw_tensor_rank, 5) @@ -482,7 +482,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype; CUtensorMapDataType cu_dtype; switch (tensor_dtype.code()) { - case DataType::kInt: + case kDLInt: // int switch (tensor_dtype.bits()) { case 8: @@ -499,7 +499,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kUInt: + case kDLUInt: // unsigned int switch (tensor_dtype.bits()) { case 8: @@ -519,7 +519,7 @@ 
TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kFloat: + case kDLFloat: // float switch (tensor_dtype.bits()) { case 16: @@ -536,7 +536,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kBFloat: + case kDLBfloat: // bfloat switch (tensor_dtype.bits()) { case 16: @@ -547,15 +547,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { << "Unsupported data type " << ffi::DLDataTypeToString(tensor_dtype); } break; - case DataType::kFloat8_e4m3fn: + case kDLFloat8_e4m3fn: // NV float8 e4m3 cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; - case DataType::kFloat8_e5m2: + case kDLFloat8_e5m2: // NV float8 e5m2 cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; - case DataType::kFloat4_e2m1fn: + case kDLFloat4_e2m1fn: #if (CUDA_VERSION >= 12080) // Packed FP4 in GMEM, unpacked into SMEM/TMEM-facing tiles. cu_dtype = CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; diff --git a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc index 017796918444..60959f2aa9fe 100644 --- a/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc @@ -66,6 +66,11 @@ namespace tvm { namespace codegen { +TVM_FFI_INLINE int GetVectorBytes(const PrimType& dtype) { + TVM_FFI_ICHECK(dtype.IsFixedLengthVector() || dtype.IsScalar()); + return dtype.bits() * dtype.lanes() / 8; +} + // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: @@ -97,12 +102,12 @@ class CodeGenHexagon final : public CodeGenCPU { void CreatePrintf(const std::string& format, llvm::ArrayRef format_args) final; private: - TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, - llvm::ArrayRef indices, DataType value_dtype) final; + TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, PrimType buffer_element_dtype, + llvm::ArrayRef indices, 
PrimType value_dtype) final; bool IsQHLFunction(const std::string& func); - llvm::Value* VectorLookupLoad(Buffer buffer, DataType buffer_type, ffi::Array indices); + llvm::Value* VectorLookupLoad(Buffer buffer, PrimType buffer_type, ffi::Array indices); llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); std::vector fqhl_list_ = { "tvm_vect_qhmath_hvx_cos_ahf", "tvm_vect_qhmath_hvx_tanh_ahf", @@ -149,8 +154,9 @@ void CodeGenHexagon::InitTarget() { llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg) { - int num_lanes = args[1].dtype().lanes(); - int vector_length = native_vector_bits_ / args[1].dtype().bits(); + PrimType arg_ty = args[1].ty(); + int num_lanes = arg_ty.lanes(); + int vector_length = native_vector_bits_ / arg_ty.bits(); num_lanes = ((num_lanes + vector_length - 1) / vector_length) * vector_length; std::vector vect_split; for (int i = 0; i < num_lanes / vector_length; ++i) { @@ -181,8 +187,9 @@ bool CodeGenHexagon::IsQHLFunction(const std::string& func) { llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, ffi::String global_symbol, const ffi::Array& args, bool skip_first_arg) { - int num_lanes = args[1].dtype().lanes(); - int vector_length = native_vector_bits_ / args[1].dtype().bits(); + PrimType arg_ty = args[1].ty(); + int num_lanes = arg_ty.lanes(); + int vector_length = native_vector_bits_ / arg_ty.bits(); if (IsQHLFunction(global_symbol) && (num_lanes > vector_length)) return CreateCallExternQHL(ret_type, global_symbol, args, skip_first_arg); return CodeGenCPU::CreateCallExtern(ret_type, global_symbol, args, skip_first_arg); @@ -192,7 +199,7 @@ llvm::Value* CodeGenHexagon::VisitExpr_(const BufferLoadNode* op) { if (!op->buffer.same_as(op->buffer->data)) { // Check if we can generate a vector lookup. 
if (!op->indices[0].as()) { - if (auto* vlut = VectorLookupLoad(op->buffer, op->dtype, op->indices)) { + if (auto* vlut = VectorLookupLoad(op->buffer, PrimType(op->ty()->dtype), op->indices)) { return vlut; } } @@ -261,9 +268,9 @@ void CodeGenHexagon::CreatePrintf(const std::string& format, } CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_ptr, - DataType buffer_element_dtype, + PrimType buffer_element_dtype, llvm::ArrayRef indices, - DataType value_dtype) { + PrimType value_dtype) { // Flat indices get delegated to the LLVM codegen. if (indices.size() == 1) { return CodeGenCPU::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype); @@ -274,7 +281,7 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_pt << "-d buffer indices"; // Use the first index to identify the pointer. - DataType dtype_void_ptr = DataType::Handle(); + PrimType dtype_void_ptr = PrimType::Handle(); CodeGenLLVM::TypedPointer buffer_chunk_ptr_ptr = CodeGenCPU::CreateBufferPtr(buffer_ptr, dtype_void_ptr, {indices[0]}, dtype_void_ptr); llvm::Value* buffer_chunk_ptr = @@ -317,10 +324,11 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, return builder_->CreateCall(intf_callee, conv_args); } -llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_type, +llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, PrimType buffer_type, ffi::Array indices) { PrimExpr index = indices[0]; - if (!index.dtype().is_fixed_length_vector()) { + PrimType index_ty = index.ty(); + if (!index_ty.IsFixedLengthVector()) { return nullptr; } @@ -329,16 +337,16 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ int table_elem_count = arith::Analyzer()->Simplify(buffer->shape[0]).as()->value; if (table_elem_count <= 0 || table_elem_count > 256) return nullptr; - auto int32 = DataType::Int(32); + auto int32 = PrimType::Int(32); auto native_vector_bytes = native_vector_bits_ 
/ 8; // Indexes - llvm::Value* trunc = MakeValue(Cast(index.dtype().with_bits(8), index)); + llvm::Value* trunc = MakeValue(Cast(index_ty.WithBits(8), index)); llvm::Value* index_pad = CreateVecPad(trunc, native_vector_bytes); // Values std::vector vloads; - DataType table_type = buffer_type.with_lanes(table_elem_count); + PrimType table_type = buffer_type.WithLanes(table_elem_count); auto table_all = MakeValue(BufferLoad(buffer, { @@ -347,7 +355,7 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ // The number of value vectors should be a power of 2. int table_vec_count = llvm::PowerOf2Ceil(GetVectorBytes(table_type) / native_vector_bytes); - int table_vec_length = native_vector_bytes / buffer_type.bytes(); + int table_vec_length = native_vector_bytes / GetVectorBytes(buffer_type); for (int i = 0; i != table_vec_count; ++i) { // CreateVecSlice will generate undefs for elements outside the source vector. vloads.push_back(CreateVecSlice(table_all, i * table_vec_length, table_vec_length)); diff --git a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc index 3e46e322a881..928df03f38aa 100644 --- a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc @@ -50,7 +50,7 @@ inline PrimExpr TVMExternCall(const tirx::CallNode* call, const std::string& fna for (PrimExpr arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(call->ty(), tirx::builtin::call_pure_extern(), new_args); } template @@ -72,14 +72,16 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + PrimType x_ty = x.ty(); + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() 
|| x_ty.IsScalableVector()) && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif - new_args.push_back(IntImm(DataType::UInt(32), id)); - new_args.push_back(IntImm(DataType::UInt(32), num_sign)); + new_args.push_back(IntImm(PrimType::UInt(32), id)); + new_args.push_back(IntImm(PrimType::UInt(32), num_sign)); new_args.insert(new_args.end(), call->args.begin(), call->args.end()); - return tirx::Call(call->dtype, tirx::builtin::call_llvm_pure_intrin(), new_args); + return tirx::Call(call->ty(), tirx::builtin::call_llvm_pure_intrin(), new_args); } void RegisterHexagonIntrinRules() { @@ -117,6 +119,7 @@ TVM_REGISTER_OP("tirx.tanh") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement @@ -130,14 +133,15 @@ TVM_REGISTER_OP("tirx.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } #endif - PrimExpr one = tirx::MakeConst(x.dtype(), 1); - PrimExpr two = tirx::MakeConst(x.dtype(), 2); - PrimExpr neg_two = tirx::MakeConst(x.dtype(), -2); + PrimExpr one = tirx::MakeConst(x_ty, 1); + PrimExpr two = tirx::MakeConst(x_ty, 2); + PrimExpr neg_two = tirx::MakeConst(x_ty, -2); PrimExpr exp_neg2x = exp(neg_two * x); PrimExpr exp_pos2x = exp(two * x); @@ -145,7 +149,7 @@ TVM_REGISTER_OP("tirx.tanh") PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); // MakeConst can handle both vector and scalar types. 
- PrimExpr tanh_x = tirx::Select(x >= tirx::MakeConst(x.dtype(), 0), tanh_pos, tanh_neg); + PrimExpr tanh_x = tirx::Select(x >= tirx::MakeConst(x_ty, 0), tanh_pos, tanh_neg); return tanh_x; }); @@ -154,6 +158,7 @@ TVM_REGISTER_OP("tirx.tan") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -166,7 +171,8 @@ TVM_REGISTER_OP("tirx.tan") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -184,6 +190,7 @@ TVM_REGISTER_OP("tirx.sigmoid") const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + PrimType x_ty = x.ty(); #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); @@ -195,21 +202,22 @@ TVM_REGISTER_OP("tirx.sigmoid") useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } - PrimExpr MinBound = tirx::MakeConst(x.dtype(), -8); - PrimExpr MaxBound = tirx::MakeConst(x.dtype(), 8); + PrimExpr MinBound = tirx::MakeConst(x_ty, -8); + PrimExpr MaxBound = tirx::MakeConst(x_ty, 8); const PrimExpr v1 = tirx::Max(x, MinBound); const PrimExpr v2 = tirx::Min(v1, MaxBound); ffi::Array new_args = {v2}; - const tirx::Call new_call = tirx::Call(call->dtype, call->op, new_args); + const tirx::Call new_call = tirx::Call(call->ty(), call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { + if (x_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16) && + (x_ty.IsFixedLengthVector() || 
x_ty.IsScalableVector()) && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } #endif - PrimExpr one = tirx::MakeConst(x.dtype(), 1); + PrimExpr one = tirx::MakeConst(x_ty, 1); return one / (one + exp(-x)); }); diff --git a/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc b/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc index d555fb77cfae..c063ae62b1bd 100644 --- a/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc +++ b/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc @@ -21,8 +21,8 @@ #include #include #include +#include #include -#include #include #include @@ -469,7 +469,7 @@ int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) // Prepare zero_block int64_t block_nbytes = 2048; void* zero_block = device_api->AllocDataSpace(conv_utils::hexagon_device, 1, &block_nbytes, - tvm::runtime::DataType::UInt(8), vtcm_scope); + DLDataType{kDLUInt, 8, 1}, vtcm_scope); memset(zero_block, 0, 2048); // FIXME: Setting bias to zero_block: this works for up to 256 output channels. 
diff --git a/src/backend/metal/codegen/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc index 3f483f79aaed..e6ef1647e5bf 100644 --- a/src/backend/metal/codegen/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -46,7 +46,7 @@ void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -97,7 +97,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; - if (!v.dtype().is_handle()) break; + if (!v.ty().IsHandle()) break; this->stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); @@ -126,24 +126,24 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < func->params.size(); ++i) { Var v = func->params[i]; - TVM_FFI_ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.ty().IsHandle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; - if (v.dtype().bits() == 32) { + if (v.ty().bits() == 32) { decl_stream << " "; - PrintType(v.dtype(), decl_stream); + PrintType(v.ty()->dtype, decl_stream); decl_stream << " " << vid << "[2];\n"; vref << varg << "." << vid << "[0]"; - } else if (v.dtype().bits() == 64) { + } else if (v.ty().bits() == 64) { decl_stream << " "; - PrintType(v.dtype(), decl_stream); + PrintType(v.ty()->dtype, decl_stream); decl_stream << " " << vid << ";\n"; vref << varg << "." << vid; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; vref << varg << "." 
<< vid << ".v_"; - PrintType(v.dtype(), vref); + PrintType(v.ty()->dtype, vref); } var_idmap_[v.get()] = vref.str(); } @@ -165,10 +165,14 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { if (work_dim != 0) { // use ushort by default for now stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + PrintType(DLDataType{kDLUInt, static_cast(thread_index_bits_), + static_cast(work_dim)}, + stream); stream << " blockIdx [[threadgroup_position_in_grid]],\n"; stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); + PrintType(DLDataType{kDLUInt, static_cast(thread_index_bits_), + static_cast(work_dim)}, + stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; @@ -190,28 +194,29 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { if (thread_work_dim_ <= 1) { vname = vname.substr(0, iv->thread_tag.length() - 2); } - var_idmap_[iv->var.get()] = - CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo( + vname, DLDataType{kDLUInt, static_cast(thread_index_bits_), 1}, iv->var.ty()->dtype); } -void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { // Need to care about sizes and alignment of half3/float3 because tirx representation might not // be aware of Metal half3/float3 details and can treat them as just three elements, // while sizes and alignmnents of 
half3/float3 are one element more (half3-8 bytes/ @@ -239,8 +244,8 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -268,11 +273,12 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_bfloat16()) { + } else if (t.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { os << "bfloat"; return; } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to Metal type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to Metal type"; } void CodeGenMetal::PrintStorageSync(const CallNode* op) { @@ -288,12 +294,12 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { } } -void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenMetal::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" @@ -328,11 +334,14 @@ void CodeGenMetal::VisitStmt_(const AllocBufferNode* op) { auto scope = GetPtrStorageScope(op->buffer->data); alloc_storage_scope_[op->buffer->data.get()] = scope; - DataType dtype = op->buffer->dtype; + DLDataType dtype = op->buffer->dtype->dtype; if (scope == "metal.simdgroup") { - TVM_FFI_ICHECK(dtype == DataType::Float(16) || dtype == DataType::Float(32) || - dtype == DataType::BFloat(16)) - << "Only float16, float32, and bfloat16 are supported, but got " << dtype; + bool supported_simdgroup_dtype = dtype == 
DLDataType{kDLFloat, 16, 1} || + dtype == DLDataType{kDLFloat, 32, 1} || + dtype == DLDataType{kDLBfloat, 16, 1}; + TVM_FFI_ICHECK(supported_simdgroup_dtype) + << "Only float16, float32, and bfloat16 are supported, but got " + << ffi::DLDataTypeToString(dtype); TVM_FFI_ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got " << constant_size << " bytes\n"; @@ -360,8 +369,8 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLI void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); - PrintType(op->dtype, os); + int lanes = op->ty().lanes(); + PrintType(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -422,7 +431,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } else if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << "))"; @@ -442,9 +451,9 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO temp << "NAN"; } else { temp << std::scientific << op->value; - if (op->dtype.bits() == 32) + if (op->ty().bits() == 32) temp << 'f'; - else if (op->dtype.bits() == 16) + else if (op->ty().bits() == 16) temp << 'h'; } MarkConst(temp.str()); diff --git a/src/backend/metal/codegen/codegen_metal.h b/src/backend/metal/codegen/codegen_metal.h index b92608aecfa1..ffa9a321aa43 100644 --- a/src/backend/metal/codegen/codegen_metal.h +++ b/src/backend/metal/codegen/codegen_metal.h @@ -43,13 +43,14 @@ class CodeGenMetal final : public CodeGenC { void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, 
std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element - void PrintVecElemLoad(const std::string& vec, DataType t, int i, + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; // NOLINT(*) // print store of single element. - void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; // overload visitor void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/src/backend/metal/codegen/intrin_rule_metal.cc b/src/backend/metal/codegen/intrin_rule_metal.cc index c807ac4c2e8a..999fe526f04e 100644 --- a/src/backend/metal/codegen/intrin_rule_metal.cc +++ b/src/backend/metal/codegen/intrin_rule_metal.cc @@ -31,7 +31,7 @@ namespace intrin { using tirx::FLowerIntrinsic; struct MetalWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& metal_simd_shuffle_op = Op::Get("tirx.metal.simd_shuffle"); return metal_simd_shuffle_op; @@ -52,7 +52,7 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array metal_args{{call->args[1], call->args[2]}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), metal_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), metal_args); } void RegisterMetalIntrinRules() { @@ -81,7 +81,7 @@ TVM_REGISTER_OP("tirx.round") for (auto arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, 
tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index 51719785195b..001d4a33b081 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -84,7 +84,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { // Storage scope qualifiers for textures are inferred // and set prior to function codegen. continue; - } else if (arg.dtype().is_handle()) { + } else if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -189,26 +189,27 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { } else { os << "get_group_id(" << ts.dim_index << ")"; } - var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(os.str(), DLDataType{kDLUInt, 64, 1}, iv->var.ty()->dtype); } -void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } bool fail = false; - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: os << "half"; @@ -230,14 +231,14 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << lanes; return; } - } else if (t.is_bool()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLBool)) { os << "uint"; if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) { os << lanes; return; } - } else if 
(t.is_uint() || t.is_int()) { - if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -266,7 +267,8 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to OpenCL type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << ffi::DLDataTypeToString(raw_t) + << " to OpenCL type"; } void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) @@ -286,41 +288,44 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(* } } -void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecAddr(const BufferNode* buffer, DLDataType t, PrimExpr base, std::ostream& os) { // NOLINT(*) const VarNode* buffer_var = buffer->data.get(); - if (!HandleTypeMatch(buffer_var, t.element_of())) { + DLDataType elem_type{t.code, t.bits, 1}; + if (!HandleTypeMatch(buffer_var, elem_type)) { os << '('; auto it = alloc_storage_scope_.find(buffer_var); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - PrintType(t.element_of(), os); + PrintType(elem_type, os); os << "*)"; } os << GetVarID(buffer_var) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DLDataType t, const BufferNode* buffer, PrimExpr base) { std::ostringstream os; - os << "vload" << t.lanes() << "(0, "; + os << "vload" << PrimType(t).lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); os << ")"; return os.str(); } -void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DLDataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); - stream << "vstore" << 
t.lanes() << "(" << value << ", 0, "; + stream << "vstore" << PrimType(t).lanes() << "(" << value << ", 0, "; PrintVecAddr(buffer, t, base, stream); stream << ");\n"; } -void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, +void CodeGenOpenCL::PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) { // NOLINT(*) - TVM_FFI_ICHECK_GT(t.lanes(), 1); - if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + PrimType t_ty(t); + int lanes = t_ty.lanes(); + TVM_FFI_ICHECK_GT(lanes, 1); + if (t.bits == 8 && (t_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt))) { if (i != 0) { os << "|"; } @@ -334,7 +339,7 @@ void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& v os << ")("; } os << value; - if (i != t.lanes() - 1) { + if (i != lanes - 1) { os << ","; } else { os << "))"; @@ -376,14 +381,14 @@ void CodeGenOpenCL::PrintRestrict(const Var& v, std::ostream& os) { } } -std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType target) { +std::string CodeGenOpenCL::CastFromTo(std::string value, DLDataType from, DLDataType target) { if (from == target) return value; return CastTo(value, target); } -std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { +std::string CodeGenOpenCL::CastTo(std::string value, DLDataType target) { std::ostringstream os; - if (target == DataType::Bool()) { + if (target == DLDataType{kDLBool, 8, 1}) { os << "("; os << "("; this->PrintType(target, os); @@ -422,7 +427,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - this->PrintType(load->dtype.element_of(), os); + this->PrintType(DLDataType{load->ty()->dtype.code, load->ty()->dtype.bits, 1}, os); os << " *)" << this->GetVarID(load->buffer->data.get()) << " + "; this->PrintExpr(load->indices[0], os); os << ')'; @@ -434,13 +439,14 @@ void 
CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { const int channel_size = op->args[4].as_or_throw()->value; TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; - DataType channel_type = runtime::GetChannelType(channel_size); + DLDataType channel_type = runtime::GetChannelType(channel_size); - DataType buffer_type = ptr_type->element_type.as()->dtype; + DLDataType buffer_type = ptr_type->element_type.as()->dtype; std::stringstream ss; this->PrintExpr(op->args[5], ss); std::string value; - value = this->SSAGetID(ss.str(), buffer_type.with_lanes(channel_size / buffer_type.bits())); + value = this->SSAGetID(ss.str(), + PrimType(buffer_type).WithLanes(channel_size / buffer_type.bits)->dtype); if (channel_size == 64) { os << "write_imageh("; } else if (channel_size == 128) { @@ -467,11 +473,11 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { enable_compliant_texture_reads_ = true; std::stringstream ss; const int channel_size = op->args[4].as_or_throw()->value; - const int data_lanes = channel_size / op->dtype.bits(); + const int data_lanes = channel_size / op->ty().bits(); TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; ss << "as_"; - this->PrintType(op->dtype.with_lanes(data_lanes), ss); + this->PrintType(op->ty().WithLanes(data_lanes)->dtype, ss); ss << "("; if (channel_size == 64) { ss << "READ_IMAGEH("; @@ -493,7 +499,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(IntImm::Int32(0), ss); ss << "))))"; - std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(data_lanes)); + std::string rhs = SSAGetID(ss.str(), op->ty().WithLanes(data_lanes)->dtype); if (auto ramp = op->args.back().as()) { if (ramp->base.as() && *tirx::as_const_int(ramp->base) == 0 && *tirx::as_const_int(ramp->lanes) == data_lanes && @@ -501,10 +507,10 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* 
op, std::ostream& os) { os << rhs; } else if (*tirx::as_const_int(ramp->stride) == 1) { os << "(*("; - this->PrintType(op->dtype.with_lanes(*tirx::as_const_int(ramp->lanes)), os); + this->PrintType(op->ty().WithLanes(*tirx::as_const_int(ramp->lanes))->dtype, os); os << "*)"; os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); + this->PrintType(op->ty().WithLanes(1)->dtype, os); os << "*)&" << rhs << " + "; this->PrintExpr(ramp->base, os); os << "))"; @@ -513,7 +519,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } } else { os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); + this->PrintType(op->ty().WithLanes(1)->dtype, os); os << "*)&" << rhs << ")["; this->PrintExpr(op->args.back(), os); os << "]"; @@ -521,7 +527,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { auto func = op->args[0].as_or_throw(); // Enable atomics extension if used. 
- if (func->value == "atomic_add" && op->dtype.is_float()) { + if (func->value == "atomic_add" && op->ty().code() == DLDataTypeCode::kDLFloat) { enable_atomics_ = true; this->PrintCallExtern(GetType(ffi::GetRef(op)), "atomic_add_float_emu", op->args, true, os); @@ -540,9 +546,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); os << "(("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -553,9 +559,9 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) os << "(("; - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << ")("; - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; @@ -579,18 +585,18 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N template inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { - if (op->dtype.lanes() == 1) { + if (op->ty().lanes() == 1) { os << opstr << "(("; - p->PrintType(op->a->dtype, os); + p->PrintType(op->a.ty()->dtype, os); os << ")"; p->PrintExpr(op->a, os); os << ", ("; - p->PrintType(op->b->dtype, os); + p->PrintType(op->b.ty()->dtype, os); os << ")"; p->PrintExpr(op->b, os); os << ')'; } else { - p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + p->PrintVecBinaryOp(opstr, op->ty()->dtype, op->a, op->b, os); } } @@ -604,14 +610,16 @@ void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) 
std::string opstr; - if (op->dtype.is_int() || op->dtype.is_uint()) { + PrimType op_ty = op->ty(); + if (op_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { opstr = "%"; } else { - TVM_FFI_ICHECK(op->dtype.is_float()) - << "Expected floating point or integer dtype in Mod, but got " << op->dtype; + TVM_FFI_ICHECK(op_ty.code() == DLDataTypeCode::kDLFloat) + << "Expected floating point or integer dtype in Mod, but got " + << ffi::DLDataTypeToString(op->ty()->dtype); opstr = "fmod"; } - if (op->dtype.lanes() == 1) { + if (op_ty.lanes() == 1) { if (isalpha(opstr.c_str()[0])) { os << opstr.c_str() << '('; this->PrintExpr(op->a, os); @@ -626,7 +634,7 @@ void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT os << ')'; } } else { - this->PrintVecBinaryOp(opstr.c_str(), op->dtype, op->a, op->b, os); + this->PrintVecBinaryOp(opstr.c_str(), op->ty()->dtype, op->a, op->b, os); } } @@ -634,11 +642,11 @@ void CodeGenOpenCL::VisitExpr_(const AndNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); oss.str(""); os << " && "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); os << ")"; } @@ -646,11 +654,11 @@ void CodeGenOpenCL::VisitExpr_(const OrNode* op, std::ostream& os) { std::ostringstream oss; os << "("; this->PrintExpr(op->a, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); oss.str(""); os << " || "; this->PrintExpr(op->b, oss); - os << CastTo(oss.str(), op->dtype); + os << CastTo(oss.str(), op->ty()->dtype); os << ")"; } @@ -658,18 +666,19 @@ void CodeGenOpenCL::VisitExpr_(const SelectNode* op, std::ostream& os) { std::ostringstream oss; os << "select("; PrintExpr(op->false_value, oss); - os << CastFromTo(oss.str(), op->false_value.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->false_value.ty()->dtype, 
op->ty()->dtype); oss.str(""); os << ", "; PrintExpr(op->true_value, oss); - os << CastFromTo(oss.str(), op->true_value.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->true_value.ty()->dtype, op->ty()->dtype); oss.str(""); os << ", "; PrintExpr(op->condition, oss); - if (op->dtype.is_float()) { - os << CastTo(oss.str(), DataType::Int(op->dtype.bits(), op->dtype.lanes())); + if (op->ty().code() == DLDataTypeCode::kDLFloat) { + os << CastTo(oss.str(), DLDataType{kDLInt, static_cast(op->ty().bits()), + static_cast(op->ty().lanes())}); } else { - os << CastFromTo(oss.str(), op->condition.dtype(), op->dtype); + os << CastFromTo(oss.str(), op->condition.ty()->dtype, op->ty()->dtype); } os << ")"; } diff --git a/src/backend/opencl/codegen/codegen_opencl.h b/src/backend/opencl/codegen/codegen_opencl.h index d588a18c2029..47667e30663a 100644 --- a/src/backend/opencl/codegen/codegen_opencl.h +++ b/src/backend/opencl/codegen/codegen_opencl.h @@ -46,20 +46,20 @@ class CodeGenOpenCL final : public CodeGenC { void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void PrintType(const Type& type, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final; - void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base, + std::string GetVecLoad(DLDataType t, const BufferNode* buffer, PrimExpr base) final; + void PrintVecStore(const BufferNode* buffer, DLDataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) - void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + void PrintVecElemLoadExpr(DLDataType t, int i, const std::string& value, std::ostream& os) final; // NOLINT(*) // the 
address of load/store - void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base, - std::ostream& os); // NOLINT(*) - void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) - std::string CastTo(std::string value, DataType target); // NOLINT(*) - void SetTextureScope(const std::unordered_map&); // NOLINT(*) + void PrintVecAddr(const BufferNode* buffer, DLDataType t, PrimExpr base, + std::ostream& os); // NOLINT(*) + void PrintRestrict(const Var& v, std::ostream& os) final; // NOLINT(*) + std::string CastFromTo(std::string value, DLDataType from, DLDataType target); // NOLINT(*) + std::string CastTo(std::string value, DLDataType target); // NOLINT(*) + void SetTextureScope(const std::unordered_map&); // NOLINT(*) // overload visitor void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) diff --git a/src/backend/opencl/codegen/intrin_rule_opencl.cc b/src/backend/opencl/codegen/intrin_rule_opencl.cc index f0f58be84d10..669fd1863b39 100644 --- a/src/backend/opencl/codegen/intrin_rule_opencl.cc +++ b/src/backend/opencl/codegen/intrin_rule_opencl.cc @@ -42,7 +42,7 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { << "Intel warp shuffle dose not support width != warp_size"; ffi::Array opencl_args{ {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - return Call(call->dtype, builtin::call_pure_extern(), opencl_args); + return Call(e.ty(), builtin::call_pure_extern(), opencl_args); } void RegisterOpenCLIntrinRules() { @@ -75,7 +75,7 @@ TVM_REGISTER_OP("tirx.round") for (auto arg : call->args) { new_args.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); + return tirx::Call(e.ty(), tirx::builtin::call_pure_extern(), new_args); }); TVM_REGISTER_OP("tirx.nearbyint") diff --git a/src/backend/opencl/runtime/opencl_common.h b/src/backend/opencl/runtime/opencl_common.h index 
3b99fa166def..4fc7ce85e383 100644 --- a/src/backend/opencl/runtime/opencl_common.h +++ b/src/backend/opencl/runtime/opencl_common.h @@ -186,24 +186,25 @@ inline const char* CLGetErrorString(cl_int error) { } inline cl_channel_type DTypeToOpenCLChannelType(DLDataType data_type) { - DataType dtype(data_type); - dtype = dtype.with_lanes(1); + DLDataType dtype = data_type; + // OpenCL image channel type depends on the scalar element type, not vector lanes. + dtype.lanes = 1; - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { return CL_FLOAT; - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { return CL_HALF_FLOAT; - } else if (dtype == DataType::Int(8)) { + } else if (dtype == DLDataType{kDLInt, 8, 1}) { return CL_SIGNED_INT8; - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { return CL_SIGNED_INT16; - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { return CL_SIGNED_INT32; - } else if (dtype == DataType::UInt(8)) { + } else if (dtype == DLDataType{kDLUInt, 8, 1}) { return CL_UNSIGNED_INT8; - } else if (dtype == DataType::UInt(16)) { + } else if (dtype == DLDataType{kDLUInt, 16, 1}) { return CL_UNSIGNED_INT16; - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == DLDataType{kDLUInt, 32, 1}) { return CL_UNSIGNED_INT32; } TVM_FFI_THROW(InternalError) << "data type is not supported in OpenCL runtime yet: " << dtype; diff --git a/src/backend/opencl/runtime/opencl_device_api.cc b/src/backend/opencl/runtime/opencl_device_api.cc index eeb8e95ad543..0b53a1915192 100644 --- a/src/backend/opencl/runtime/opencl_device_api.cc +++ b/src/backend/opencl/runtime/opencl_device_api.cc @@ -779,14 +779,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t height = shape[1]; int64_t depth = shape[2]; int64_t channel_size = args[7].cast(); - DataType channel_type = GetChannelType(channel_size); + DLDataType channel_type = 
GetChannelType(channel_size); Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; DLDataType type_hint; - type_hint.code = channel_type.code(); - type_hint.bits = channel_type.bits(); - type_hint.lanes = channel_type.lanes(); + type_hint = channel_type; *rv = OpenCLWorkspace::Global()->AllocDataSpace( dev, static_cast(width), static_cast(height), diff --git a/src/backend/opencl/runtime/texture.h b/src/backend/opencl/runtime/texture.h index a8711805cbfa..3aa2d3681142 100644 --- a/src/backend/opencl/runtime/texture.h +++ b/src/backend/opencl/runtime/texture.h @@ -120,15 +120,13 @@ size_t GetTextureMemorySize(T shape, int bits, int lanes, std::string mem_scope, /*! * \brief Returns the standard channel datatype for any given type. * \param channel_size The Number of bits in a Channel - * \return DataType to be used in the codegen. + * \return DLDataType to be used in the codegen. */ -inline DataType GetChannelType(size_t channel_size) { - DataType channel_type; - +inline DLDataType GetChannelType(size_t channel_size) { if (channel_size == 128) - return DataType::Float(32, 4); + return DLDataType{kDLFloat, 32, 4}; else if (channel_size == 64) - return DataType::Float(16, 4); + return DLDataType{kDLFloat, 16, 4}; TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } diff --git a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc index 22ce75cddade..6f70343f46a4 100644 --- a/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc +++ b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc @@ -100,7 +100,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* buf = nullptr; StorageInfo& info = alloc_storage_info_[op->buffer->data.get()]; auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - DataType dtype = op->buffer->dtype; + PrimType dtype = op->buffer->dtype; if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == 
".dyn") { LOG(WARNING) << "Dynamic shared memory support for rocm is experimental."; @@ -188,7 +188,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); #endif llvm::Value* result = builder_->CreateCall(f, {}); - return this->CreateCast(DataType::Int(32), iv->var->dtype, result); + return this->CreateCast(PrimType::Int(32), iv->var.ty(), result); } llvm::Value* CreateStorageSync(const CallNode* op) final { @@ -220,10 +220,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) final { if (op->op.same_as(builtin::atomic_add())) { - TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + PrimType value_ty = op->args[1].ty(); + TVM_FFI_ICHECK(value_ty.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); - if (op->args[1]->dtype.is_float()) { + if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat)) { return builder_->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, v0, v1, llvm::MaybeAlign(), llvm::AtomicOrdering::Monotonic); } diff --git a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc index 4859fd5f4a24..db0f113b9c8b 100644 --- a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc +++ b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc @@ -50,14 +50,14 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->dtype.bits(); + intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->ty().bits(); ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { new_args.push_back(arg); } - return Call(call->dtype, builtin::call_pure_extern(), new_args); + return Call(call->ty(), builtin::call_pure_extern(), 
new_args); } inline PrimExpr DispatchShuffle(const PrimExpr& e) { @@ -66,15 +66,17 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size PrimExpr var = call->args[1]; - TVM_FFI_ICHECK_EQ(var.dtype().bits(), 32); + PrimType var_ty = var.ty(); + TVM_FFI_ICHECK_EQ(var_ty.bits(), 32); // get own lane in self (__lane_id) PrimExpr minus_one = IntImm::Int32(-1); PrimExpr zero = IntImm::Int32(0); - PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), + PrimType i32_ty = PrimType::Int(32); + PrimExpr lo = Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); - PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); + PrimExpr self = + Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}); // compute lane to get from PrimExpr width = call->args[3]; @@ -93,12 +95,12 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = Select((self & (width - 1)) + delta >= width, self, index); } // reinterprete var as int32 - bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; - PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var); - PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), + bool is_int32 = var_ty.MatchesElementType(DLDataTypeCode::kDLInt, 32); + PrimExpr source = is_int32 ? 
var : reinterpret(PrimType::Int(32), var); + PrimExpr res = Call(i32_ty, builtin::call_pure_extern(), {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); if (!is_int32) { - res = reinterpret(var.dtype(), res); + res = reinterpret(var_ty, res); } return res; } diff --git a/src/backend/trn/codegen/codegen_trn.cc b/src/backend/trn/codegen/codegen_trn.cc index eb9d7ca4b437..631df21f8b08 100644 --- a/src/backend/trn/codegen/codegen_trn.cc +++ b/src/backend/trn/codegen/codegen_trn.cc @@ -110,7 +110,7 @@ void CodeGenTrainium::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { size_t num_buffer = 0; for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { Var v = func->params[i]; - if (!v.dtype().is_handle()) { + if (!v.ty().IsHandle()) { LOG(FATAL) << "Trainium codegen currently only support buffer arguments"; }; std::string vid = AllocVarID(v.get()); @@ -137,16 +137,17 @@ void CodeGenTrainium::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { this->EndScope(func_scope); } -void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenTrainium::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); TVM_FFI_ICHECK(lanes == 1) << "Trainium codegen does not support vector types"; - TVM_FFI_ICHECK(!t.is_handle()) << "Trainium codegen does not support handle type"; - TVM_FFI_ICHECK(!t.is_void()) << "Trainium codegen does not support void type"; - if (t == DataType::Bool()) { + TVM_FFI_ICHECK(!t.IsHandle()) << "Trainium codegen does not support handle type"; + TVM_FFI_ICHECK(!t.IsVoid()) << "Trainium codegen does not support void type"; + if (t.MatchesCode(DLDataTypeCode::kDLBool)) { os << "np.bool"; return; } - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { switch (t.bits()) { case 16: os << "np.float16"; @@ -160,13 +161,13 @@ void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } return; } - if (t.is_uint() || t.is_int()) { 
+ if (t.MatchesCode(DLDataTypeCode::kDLUInt, DLDataTypeCode::kDLInt)) { if (t.bits() == 1) { os << "np.bool"; return; } os << "np."; - if (t.is_uint()) { + if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { os << 'u'; } switch (t.bits()) { @@ -188,11 +189,11 @@ void CodeGenTrainium::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } return; } - if (t.is_bfloat16()) { + if (t.code() == DLDataTypeCode::kDLBfloat && t.bits() == 16) { os << "nl.bfloat16"; return; } - LOG(FATAL) << "Cannot convert type " << t << " to Trainium type"; + LOG(FATAL) << "Cannot convert type " << raw_t << " to Trainium type"; } std::string CodeGenTrainium::GetStorageScopeStr(const std::string& scope) { // NOLINT(*) @@ -215,7 +216,7 @@ void CodeGenTrainium::VisitStmt_(const AllocBufferNode* op) { this->PrintIndent(); auto scope = GetPtrStorageScope(op->buffer->data); std::ostringstream dtype_os; - PrintType(op->buffer->dtype, dtype_os); + PrintType(op->buffer->dtype->dtype, dtype_os); std::string dtype_str = dtype_os.str(); if (scope == "trn.psum") { stream << vid << " = nl.ndarray(shape=["; @@ -589,7 +590,7 @@ void CodeGenTrainium::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLI } void CodeGenTrainium::VisitExpr_(const CastNode* op, std::ostream& os) { - ctx_.dst_dtype = op->dtype; + ctx_.dst_dtype = op->ty(); CodeGenTrainium::VisitExpr(op->value, os); } diff --git a/src/backend/trn/codegen/codegen_trn.h b/src/backend/trn/codegen/codegen_trn.h index 2c3b5fd37393..ec4eaad29cce 100644 --- a/src/backend/trn/codegen/codegen_trn.h +++ b/src/backend/trn/codegen/codegen_trn.h @@ -41,7 +41,7 @@ struct NKIInstructionCtx { bool is_matmul_input = false; int buffer_index = -1; int used_var_cnt = 0; - DataType dst_dtype; + PrimType dst_dtype = PrimType::Void(); PrimExpr mask; bool tensorizing = false; }; @@ -57,7 +57,7 @@ class CodeGenTrainium final : public CodeGenC { void InitFuncState(const PrimFunc& f) final; std::string GetStorageScopeStr(const std::string& scope); // NOLINT(*) void 
VisitExpr_(const VarNode* op, std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const AllocBufferNode* op) final; // NOLINT(*) void VisitStmt_(const AttrStmtNode* op) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; // NOLINT(*) diff --git a/src/backend/trn/transform/lower_trainium_layout.cc b/src/backend/trn/transform/lower_trainium_layout.cc index ad4b206a48b2..fb1d92c5215d 100644 --- a/src/backend/trn/transform/lower_trainium_layout.cc +++ b/src/backend/trn/transform/lower_trainium_layout.cc @@ -176,8 +176,8 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { flattened = buf.GetFlattenedBuffer(); writer = flattened.CopyOnWrite(); } - if (flattened->dtype == DataType::Bool()) { - writer->dtype = DataType::Int(8); + if (flattened->dtype->dtype == DLDataType{kDLBool, 8, 1}) { + writer->dtype = PrimType::Int(8); } for (size_t i = 0; i < flattened->shape.size(); ++i) { writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); @@ -191,28 +191,30 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = StmtExprMutator::VisitStmt_(op).as_or_throw(); - bool store_returns_bool = (op->value.dtype() == DataType::Bool()); + PrimType store_value_ty = op->value.ty(); + bool store_returns_bool = store_value_ty.MatchesCode(DLDataTypeCode::kDLBool); store = VisitBufferAccess(store); if (store_returns_bool) { - TVM_FFI_ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(store->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 1})) << "Expected int8 backing array for boolean tensor"; auto writer = store.CopyOnWrite(); - writer->value = tvm::cast(DataType::Int(8), store->value); + writer->value = tvm::cast(PrimType::Int(8), store->value); return std::move(store); } return std::move(store); } PrimExpr 
VisitExpr_(const BufferLoadNode* op) final { - bool load_returns_bool = (op->dtype == DataType::Bool()); + PrimType load_ty = op->ty(); + bool load_returns_bool = load_ty.MatchesCode(DLDataTypeCode::kDLBool); BufferLoad load = StmtExprMutator::VisitExpr_(op).as_or_throw(); load = VisitBufferAccess(load); if (load_returns_bool) { - TVM_FFI_ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(load->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 1})) << "Expected int8 backing array for boolean tensor"; - load.CopyOnWrite()->dtype = DataType::Int(8); - return tvm::cast(DataType::Bool(), load); + load.CopyOnWrite()->BaseExprNode::ty = PrimType::Int(8); + return tvm::cast(PrimType::Bool(), load); } else { return std::move(load); } diff --git a/src/backend/vulkan/codegen/codegen_spirv.cc b/src/backend/vulkan/codegen/codegen_spirv.cc index 5737c60da9dc..094e31370481 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.cc +++ b/src/backend/vulkan/codegen/codegen_spirv.cc @@ -52,8 +52,8 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s const uint32_t descriptor_set = 0; for (Var arg : f->params) { - DataType t = arg.dtype(); - if (t.is_handle()) { + PrimType t = PrimType(arg.ty()->dtype); + if (t.IsHandle()) { auto* ptr = arg->type_annotation.as(); TVM_FFI_ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " @@ -64,11 +64,11 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s << "All handles passed to the Vulkan codegen must have a type_annotation as a " "PointerType, " << "and must point to a PrimType"; - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::Bool()) { + PrimType value_storage_type(prim->dtype); + if (value_storage_type == PrimType::Bool()) { // We need a physically addressable buffer type to support boolean tensors. // The loaded byte is cast to bool inside the LoadNode visitor below. 
- value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + value_storage_type = boolean_storage_type_.WithLanes(value_storage_type.lanes()); } spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), descriptor_set, i_buffer++); @@ -87,7 +87,7 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { - value_types.push_back(builder_->GetSType(pod_args[i].dtype())); + value_types.push_back(builder_->GetSType(PrimType(pod_args[i].ty()->dtype))); } if (pod_args.size() * sizeof(runtime::ArgUnion64) <= runtime::vulkan::kMaxPushConstantsBytes) { spirv::Value ptr = builder_->DeclarePushConstant(value_types); @@ -150,7 +150,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext } else { v = builder_->GetWorkgroupID(ts.dim_index); } - return builder_->Cast(builder_->GetSType(iv->var.dtype()), v); + return builder_->Cast(builder_->GetSType(PrimType(iv->var.ty()->dtype)), v); } spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { @@ -179,7 +179,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { TVM_FFI_THROW(InternalError) << "Do not support sync " << sync; } - auto type_int = builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); builder_->MakeInst(spv::OpControlBarrier, builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, sync_scope), builder_->IntImm(type_int, memory_semantics)); @@ -194,11 +194,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) { - return builder_->IntImm(builder_->GetSType(op->dtype), op->value); + return builder_->IntImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { - return 
builder_->FloatImm(builder_->GetSType(op->dtype), op->value); + return builder_->FloatImm(builder_->GetSType(PrimType(op->ty()->dtype)), op->value); } spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { @@ -206,7 +206,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) { - return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value)); + return builder_->Cast(builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->value)); } spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) { @@ -308,7 +308,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); + return builder_->CallGLSL450(builder_->GetSType(PrimType(op->ty()->dtype)), inst_id, values); } else if (op->op.same_as(builtin::bitwise_and())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -337,20 +337,20 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); - if (op->args[0].dtype().is_int()) { + if (PrimType(op->args[0].ty()->dtype).MatchesCode(DLDataTypeCode::kDLInt)) { return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b); } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } } else if (op->op.same_as(builtin::reinterpret())) { - return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), + return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::large_uint_imm())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(op->args[0].as_or_throw()->value); uint64_t high = 
static_cast(op->args[1].as_or_throw()->value); uint64_t val = (high << 32U) | low; - return builder_->UIntImm(builder_->GetSType(op->dtype), val); + return builder_->UIntImm(builder_->GetSType(PrimType(op->ty()->dtype)), val); } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); } else if (op->op.same_as(builtin::if_then_else())) { @@ -378,7 +378,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->op.same_as(builtin::popcount())) { - return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), + return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(PrimType(op->ty()->dtype)), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::call_pure_extern())) { TVM_FFI_ICHECK_GE(op->args.size(), 1U); @@ -388,7 +388,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op->dtype), values, op->dtype); + PrimType op_dtype(op->ty()->dtype); + return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op_dtype), values, op_dtype); } else { TVM_FFI_THROW(InternalError) << "SPIR-V shader cannot make extern calls. 
Graph contains extern \"" @@ -412,8 +413,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_node = op->args[0].as(); TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); - DataType ele_dtype = GetElementDataType(buffer_node); - TVM_FFI_ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; + PrimType ele_dtype = GetElementDataType(buffer_node); + TVM_FFI_ICHECK(ele_dtype.MatchesCode(DLDataTypeCode::kDLFloat)) + << "Only floating point fragment accumulator is supported"; spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; double init = static_cast(op->args[5].as_or_throw()->value); @@ -435,7 +437,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { PrimExpr dst_index = op->args[4]; PrimExpr src_ptr_expr = op->args[5]; int stride = static_cast(op->args[6].as_or_throw()->value); - auto type_int = builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; spirv::SType dst_ptr_type = @@ -443,7 +445,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value dst_ptr = builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], MakeValue(dst_index)); spirv::Value src_ptr = VisitExpr(op->args[5]); - spirv::SType type_bool = builder_->GetSType(DataType::Bool()); + spirv::SType type_bool = builder_->GetSType(PrimType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); spirv::Value loaded = @@ -494,7 +496,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { PrimExpr index = op->args[4]; PrimExpr buffer_ptr = op->args[5]; int stride = static_cast(op->args[6].as_or_throw()->value); - auto type_int = 
builder_->GetSType(DataType::Int(32)); + auto type_int = builder_->GetSType(PrimType::Int(32)); spirv::Value stride_val = builder_->IntImm(type_int, stride); std::string layout = (op->args[7].as())->value; spirv::Value dst_ptr = VisitExpr(op->args[5]); @@ -505,7 +507,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); uint32_t mask = spv::MemoryAccessMaskNone; spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, mask); - spirv::SType type_bool = builder_->GetSType(DataType::Bool()); + spirv::SType type_bool = builder_->GetSType(PrimType::Bool()); spirv::Value t_val = builder_->UIntImm(type_bool, 1); spirv::Value f_val = builder_->UIntImm(type_bool, 0); builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, stride_val, @@ -516,7 +518,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { Var buffer_var = load->buffer->data; const VarNode* buffer_node = buffer_var.get(); PrimExpr index = load->indices[0]; - DataType ele_dtype = GetElementDataType(buffer_node); + PrimType ele_dtype = GetElementDataType(buffer_node); spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::Value buffer_val = MakeValue(buffer_var); spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class); @@ -532,11 +534,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { std::vector values; spirv::Value base = MakeValue(op->base); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue(MakeConst(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(MakeConst(op->stride.ty(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -547,7 +549,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { 
spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { std::vector values; spirv::Value v = MakeValue(op->value); - int lanes = op->dtype.lanes(); + int lanes = op->ty().lanes(); for (int i = 0; i < lanes; i++) { values.push_back(v); } @@ -560,15 +562,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { Var buffer_var = op->buffer->data; PrimExpr prim_index = op->indices[0]; - DataType desired_read_type = op->dtype; - if (desired_read_type == DataType::Bool()) { - desired_read_type = boolean_storage_type_.with_lanes(desired_read_type.lanes()); + PrimType desired_read_type(op->ty()->dtype); + if (desired_read_type == PrimType::Bool()) { + desired_read_type = boolean_storage_type_.WithLanes(desired_read_type.lanes()); } auto it = storage_info_.find(buffer_var.get()); TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(desired_read_type, prim_index.dtype().lanes()); + info.CheckContentType(desired_read_type, PrimType(prim_index.ty()->dtype).lanes()); spirv::SType content_type = builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(buffer_var); @@ -588,13 +590,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); // OpTypeBool have no physical address/storage. Here, cast from // the storage type to an OpTypeBool. - if (op->dtype == DataType::Bool()) { - auto spirv_bool = builder_->GetSType(DataType::Bool()); + if (PrimType(op->ty()->dtype) == PrimType::Bool()) { + auto spirv_bool = builder_->GetSType(PrimType::Bool()); loaded = builder_->Cast(spirv_bool, loaded); } return loaded; - } else if (desired_read_type.element_of() == info.element_type) { + } else if (desired_read_type.WithLanes(1) == info.element_type) { // Requested several elements returned as an array. Read out each // element and concatenate into the result. 
std::vector values; @@ -609,21 +611,22 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { TVM_FFI_THROW(InternalError) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type << " using index of type " - << prim_index->dtype << " to produce output of type " << op->dtype; + << PrimType(prim_index.ty()->dtype) + << " to produce output of type " << PrimType(op->ty()->dtype); return spirv::Value(); } } void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { - for (int i = 0; i < ramp->dtype.lanes(); ++i) { + for (int i = 0; i < ramp->ty().lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } } else { - spirv::SType etype = builder_->GetSType(e.dtype().element_of()); + spirv::SType etype = builder_->GetSType(PrimType(e.ty()->dtype).WithLanes(1)); spirv::Value value = MakeValue(e); - for (int i = 0; i < e.dtype().lanes(); ++i) { + for (int i = 0; i < PrimType(e.ty()->dtype).lanes(); ++i) { f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i)); } } @@ -635,7 +638,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) { << "of one vector with one index"; spirv::Value vector = MakeValue(op->vectors[0]); int index = op->indices[0].as_or_throw()->value; - spirv::SType etype = builder_->GetSType(op->dtype); + spirv::SType etype = builder_->GetSType(PrimType(op->ty()->dtype)); spirv::Value element = builder_->MakeValue(spv::OpCompositeExtract, etype, vector, index); return element; } @@ -649,7 +652,7 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { auto it = storage_info_.find(buffer_var.get()); TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; - info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes()); + info.CheckContentType(PrimType(op->value.ty()->dtype), PrimType(prim_index.ty()->dtype).lanes()); spirv::SType content_type = 
builder_->GetSType(info.element_type); spirv::Value buffer = MakeValue(buffer_var); @@ -661,16 +664,16 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { mask |= spv::MemoryAccessVolatileMask; } - if (op->value.dtype() == info.element_type) { + if (PrimType(op->value.ty()->dtype) == info.element_type) { // Requested store of a single value. This may be a scalar store // or a vectorized store, based on the array element type. - TVM_FFI_ICHECK_EQ(info.element_type, op->value.dtype()) + TVM_FFI_ICHECK_EQ(info.element_type, PrimType(op->value.ty()->dtype)) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); - } else if (op->value.dtype().element_of() == info.element_type) { + } else if (PrimType(op->value.ty()->dtype).WithLanes(1) == info.element_type) { // Requested store of several arbitrarily located values. Extract // each value from the composite, then assign to the buffer. auto f = [&](int i, spirv::Value index) { @@ -681,10 +684,10 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { this->Scalarize(prim_index, f); } else { - TVM_FFI_THROW(InternalError) << "Cannot store value of type " << op->value.dtype() + TVM_FFI_THROW(InternalError) << "Cannot store value of type " << PrimType(op->value.ty()->dtype) << " into buffer variable '" << buffer_var->name_hint << "' with element type " << info.element_type - << " using index of type " << prim_index->dtype; + << " using index of type " << PrimType(prim_index.ty()->dtype); } } @@ -697,10 +700,11 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop step spirv::Value step; if (op->HasTrivialStep()) { - step = op->loop_var.dtype().is_int() ? builder_->IntImm(init_value.stype, 1) - : builder_->UIntImm(init_value.stype, 1); + step = PrimType(op->loop_var.ty()->dtype).MatchesCode(DLDataTypeCode::kDLInt) + ? 
builder_->IntImm(init_value.stype, 1) + : builder_->UIntImm(init_value.stype, 1); } else { - step = MakeValue(tvm::cast(end->dtype, *op->step)); + step = MakeValue(tvm::cast(end.ty(), *op->step)); } // Must get init label after making value(to make sure they are correct) @@ -807,7 +811,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { } void CodeGenSPIRV::VisitStmt_(const AllocBufferNode* op) { - TVM_FFI_ICHECK(!op->buffer->dtype.is_handle()); + TVM_FFI_ICHECK(!op->buffer->dtype.IsHandle()); const IntImmNode* dim_imm = op->buffer->shape[0].as(); TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation in GPU"; size_t constant_size = static_cast(dim_imm->value); @@ -848,7 +852,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocBufferNode* op) { int32_t aligned_constant_size = ((constant_size + 3) & ~0x3); buf = builder_->Allocate(etype, static_cast(aligned_constant_size), storage_class); - size_t num_bytes = op->buffer->dtype.bytes() * op->buffer->dtype.lanes() * + size_t num_bytes = ((op->buffer->dtype.bits() + 7) / 8) * op->buffer->dtype.lanes() * static_cast(aligned_constant_size); shared_memory_bytes_used_ += num_bytes; } break; @@ -897,7 +901,7 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const BindNode* op) { TVM_FFI_ICHECK(!var_map_.count(op->var.get())); - TVM_FFI_ICHECK(!op->var.dtype().is_handle()); + TVM_FFI_ICHECK(!PrimType(op->var.ty()->dtype).IsHandle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); } @@ -910,18 +914,18 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) { +spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const PrimType& dtype) { TVM_FFI_ICHECK(fragment_info_.count(buffer)); const std::string& scope = fragment_info_[buffer].scope; 
const std::string& shape_str = fragment_info_.at(buffer).shape; std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); int64_t size = dim.first * dim.second; - spirv::SType stype = builder_->GetSType(dtype.with_lanes(size), dim.first, dim.second); + spirv::SType stype = builder_->GetSType(dtype.WithLanes(size), dim.first, dim.second); fragment_info_[buffer].stype = stype; return stype; } -DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { +PrimType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { auto it = storage_info_.find(buffer); TVM_FFI_ICHECK(it != storage_info_.end()); return it->second.element_type; diff --git a/src/backend/vulkan/codegen/codegen_spirv.h b/src/backend/vulkan/codegen/codegen_spirv.h index 46fbcb696b6f..5ade6e383908 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.h +++ b/src/backend/vulkan/codegen/codegen_spirv.h @@ -142,7 +142,7 @@ class CodeGenSPIRV : public ExprFunctor, * buffer variable (AllocBufferNode) or of the parameter (shader * arguments). */ - DataType element_type{DataType()}; + PrimType element_type{PrimType::Void()}; /* \brief Check that the access type matches the known type * @@ -156,10 +156,10 @@ class CodeGenSPIRV : public ExprFunctor, * product of the number of lanes of the buffer element type and * the number of lanes of the index. 
*/ - void CheckContentType(DataType type, int index_lanes = 1) const { + void CheckContentType(PrimType type, int index_lanes = 1) const { TVM_FFI_ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint << " no previous element type defined"; - DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); + PrimType expected_type = element_type.WithLanes(index_lanes * element_type.lanes()); TVM_FFI_ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint << " as element type " << type << " using an index of size " << index_lanes << " when the element type is " @@ -167,7 +167,7 @@ class CodeGenSPIRV : public ExprFunctor, } // Update content type if it hasn't been updated. - void SetContentType(DataType type, std::string name_hint) { + void SetContentType(PrimType type, std::string name_hint) { TVM_FFI_ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint << " a second time."; this->element_type = type; @@ -191,8 +191,8 @@ class CodeGenSPIRV : public ExprFunctor, spirv::Value CreateStorageSync(const CallNode* op); void Scalarize(const PrimExpr& e, std::function f); - spirv::SType GetFragmentSType(const VarNode* buffer, const DataType& dtype); - DataType GetElementDataType(const VarNode* buffer); + spirv::SType GetFragmentSType(const VarNode* buffer, const PrimType& dtype); + PrimType GetElementDataType(const VarNode* buffer); // SPIRV-related capabilities of the target SPIRVSupport spirv_support_; @@ -213,7 +213,7 @@ class CodeGenSPIRV : public ExprFunctor, * integer type supported by the device, as not all Vulkan * implementations support int8. 
*/ - DataType boolean_storage_type_{DataType::Int(8)}; + PrimType boolean_storage_type_{PrimType::Int(8)}; // the storage scope of allocation std::unordered_map storage_info_; diff --git a/src/backend/vulkan/codegen/intrin_rule_spirv.cc b/src/backend/vulkan/codegen/intrin_rule_spirv.cc index 14287562d9e4..6deb6e0a9b61 100644 --- a/src/backend/vulkan/codegen/intrin_rule_spirv.cc +++ b/src/backend/vulkan/codegen/intrin_rule_spirv.cc @@ -39,12 +39,12 @@ PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. - cargs.push_back(IntImm(DataType::UInt(32), id)); + cargs.push_back(IntImm(PrimType::UInt(32), id)); for (PrimExpr arg : args) { cargs.push_back(arg); } - return tirx::Call(call->dtype, tirx::builtin::call_spirv_pure_glsl450(), cargs); + return tirx::Call(call->ty(), tirx::builtin::call_spirv_pure_glsl450(), cargs); } template @@ -166,21 +166,22 @@ TVM_REGISTER_OP("tirx.clz") TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; + PrimType arg_ty = arg.ty(); PrimExpr msb; - if (arg.dtype().bits() == 64) { + if (arg_ty.bits() == 64) { // SPIR-V FindUMsb intrinsic only supports 32 bit input - auto int32 = DataType::Int(32); + auto int32 = PrimType::Int(32); PrimExpr arg_hi32 = tvm::tirx::Cast(int32, arg >> 32); PrimExpr arg_lo32 = tvm::tirx::Cast(int32, arg); PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); - } else if (arg.dtype().bits() == 32) { + } else if (arg_ty.bits() == 32) { msb = CallGLSLIntrin(e); } else { TVM_FFI_THROW(InternalError) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; } - return PrimExpr(arg.dtype().bits() - 1) - msb; + return PrimExpr(arg_ty.bits() - 1) - msb; }); // clang-format on } diff --git a/src/backend/vulkan/codegen/ir_builder.cc b/src/backend/vulkan/codegen/ir_builder.cc index 
f912e482761c..ca82b06b0554 100644 --- a/src/backend/vulkan/codegen/ir_builder.cc +++ b/src/backend/vulkan/codegen/ir_builder.cc @@ -74,10 +74,10 @@ void IRBuilder::InitHeader() { void IRBuilder::InitPreDefs() { ext_glsl450_ = ExtInstImport("GLSL.std.450"); - t_int32_ = DeclareType(DataType::Int(32)); - t_uint32_ = DeclareType(DataType::UInt(32)); - t_bool_ = DeclareType(DataType::Bool()); - t_fp32_ = DeclareType(DataType::Float(32)); + t_int32_ = DeclareType(PrimType::Int(32)); + t_uint32_ = DeclareType(PrimType::UInt(32)); + t_bool_ = DeclareType(PrimType::Bool()); + t_fp32_ = DeclareType(PrimType::Float(32)); const_i32_zero_ = IntImm(t_int32_, 0); // declare void, and void functions @@ -112,14 +112,14 @@ std::vector IRBuilder::Finalize() { return data; } -SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { - if (dtype == DataType::Int(32)) { +SType IRBuilder::GetSType(const PrimType& dtype, uint32_t row, uint32_t col) { + if (dtype == PrimType::Int(32)) { return t_int32_; - } else if (dtype == DataType::Bool()) { + } else if (dtype == PrimType::Bool()) { return t_bool_; - } else if (dtype == DataType::Float(32)) { + } else if (dtype == PrimType::Float(32)) { return t_fp32_; - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == PrimType::UInt(32)) { return t_uint32_; } uint64_t type_key; @@ -151,7 +151,7 @@ SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass stora } SType t; t.id = id_counter_++; - t.type = DataType::Handle(); + t.type = PrimType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_); @@ -169,11 +169,11 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, SType arr_type; arr_type.id = id_counter_++; - arr_type.type = DataType::Handle(); + arr_type.type = PrimType::Handle(); arr_type.element_type_id = value_type.id; if (num_elems != 0) { - Value 
length = UIntImm(GetSType(DataType::UInt(32)), num_elems); + Value length = UIntImm(GetSType(PrimType::UInt(32)), num_elems); ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_); } else { ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); @@ -188,7 +188,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, // declare struct of array SType struct_type; struct_type.id = id_counter_++; - struct_type.type = DataType::Handle(); + struct_type.type = PrimType::Handle(); struct_type.element_type_id = value_type.id; ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); @@ -241,7 +241,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { if (data == 0) return GetConst_(dtype, &data); else - return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(PrimType::Float(32)), value)); } } @@ -270,7 +270,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, spv::StorageClass storage_class, ValueKind kind) { SType struct_type; struct_type.id = id_counter_++; - struct_type.type = DataType::Handle(); + struct_type.type = PrimType::Handle(); ib_.Begin(spv::OpTypeStruct).Add(struct_type); for (const SType& vtype : value_types) { ib_.Add(vtype); @@ -282,7 +282,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, i, spv::DecorationOffset, offset) .Commit(&decorate_); - DataType t = value_types[i].type; + PrimType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); TVM_FFI_ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); @@ -394,13 +394,13 @@ Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const st } } - DataType data_type; - DataType global_arr_type; + PrimType data_type; + PrimType global_arr_type; switch (built_in) { case spv::BuiltInLocalInvocationId: case 
spv::BuiltInWorkgroupId: - data_type = DataType::Int(32); - global_arr_type = data_type.with_lanes(3); + data_type = PrimType::Int(32); + global_arr_type = data_type.WithLanes(3); break; default: @@ -468,7 +468,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { } TVM_FFI_ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); - if (dtype.type == DataType::Bool()) { + if (dtype.type == PrimType::Bool()) { // bool types. if (*pvalue) { ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret); @@ -481,7 +481,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { uint64_t mask = 0xFFFFFFFFUL; ib_.Add(static_cast(pvalue[0] & mask)); if (dtype.type.bits() > 32) { - if (dtype.type.is_int()) { + if (dtype.type.MatchesCode(DLDataTypeCode::kDLInt)) { int64_t sign_mask = 0xFFFFFFFFL; const int64_t* sign_ptr = reinterpret_cast(pvalue); ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); @@ -495,20 +495,20 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { return ret; } -SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) { +SType IRBuilder::DeclareType(const PrimType& dtype, uint32_t row, uint32_t col) { AddCapabilityFor(dtype); if (dtype.lanes() == 1) { SType t; t.id = id_counter_++; t.type = dtype; - if (dtype.is_bool()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLBool)) { ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_); - } else if (dtype.is_int()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLInt)) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_); - } else if (dtype.is_uint()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 0).Commit(&global_); - } else if (dtype.is_float()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_); } else { TVM_FFI_THROW(InternalError) << "declare type do 
not support handle"; @@ -518,15 +518,15 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType t; t.id = id_counter_++; t.type = dtype; - SType base_type = GetSType(dtype.element_of()); + SType base_type = GetSType(dtype.WithLanes(1)); if (row * col == 0) { TVM_FFI_ICHECK((row == 0) && (col == 0)); ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); } else { - Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); - Value v_col = GetSpecConst(GetSType(DataType::UInt(32)), col); - Value scope = UIntImm(GetSType(DataType::UInt(32)), spv::ScopeSubgroup); + Value v_row = GetSpecConst(GetSType(PrimType::UInt(32)), row); + Value v_col = GetSpecConst(GetSType(PrimType::UInt(32)), col); + Value scope = UIntImm(GetSType(PrimType::UInt(32)), spv::ScopeSubgroup); ib_.Begin(spv::OpTypeCooperativeMatrixNV) .AddSeq(t, base_type, scope, v_row, v_col) .Commit(&global_); @@ -535,9 +535,9 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) } } -void IRBuilder::AddCapabilityFor(const DataType& dtype) { +void IRBuilder::AddCapabilityFor(const PrimType& dtype) { // Declare appropriate capabilities for int/float types - if (dtype.is_int() || dtype.is_uint()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { if (dtype.bits() == 8) { TVM_FFI_ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability. " @@ -561,7 +561,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { capabilities_used_.insert(spv::CapabilityInt64); } - } else if (dtype.is_float()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLFloat)) { if (dtype.bits() == 16) { TVM_FFI_ICHECK(spirv_support_.supports_float16) << "Vulkan target does not support Float16 capability. " @@ -584,7 +584,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // future. 
Requiring StorageBuffer8BitAccess in order to declare an // Int8 prevents use of an 8-bit loop iterator on a device that // supports Int8 but doesn't support 8-bit buffer access. - if (dtype.bits() == 8 && !dtype.is_bool()) { + if (dtype.bits() == 8 && !dtype.MatchesCode(DLDataTypeCode::kDLBool)) { TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_8bit_access) << "Vulkan target does not support StorageBuffer8BitAccess. " << "If your device supports 8-bit buffer access, " @@ -642,7 +642,7 @@ Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, } Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, - const DataType& dtype) { + const PrimType& dtype) { if (args.size() != 3) { TVM_FFI_THROW(InternalError) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; } @@ -653,9 +653,9 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect << "If your device supports integer dot product operations, " << "please either add -mattr=+dotprod to the target, " << "or query all device parameters by adding -from_device=0."; - if (dtype.is_int()) { + if (dtype.MatchesCode(DLDataTypeCode::kDLInt)) { ib_.Begin(spv::OpSDotAccSatKHR).AddSeq(ret_type, val); - } else if (dtype.is_uint()) { + } else if (dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { ib_.Begin(spv::OpUDotAccSatKHR).AddSeq(ret_type, val); } else { TVM_FFI_THROW(InternalError) << "Unsupported type"; @@ -674,15 +674,15 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect Value IRBuilder::Concat(const std::vector& vec) { bool is_const = vec[0].flag == kConstant; - DataType etype = vec[0].stype.type; + PrimType etype = vec[0].stype.type; int lanes = etype.lanes(); for (size_t i = 1; i < vec.size(); ++i) { - TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.element_of()) + TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.WithLanes(1)) << "Cannot concat vector of different element type"; lanes += vec[i].stype.type.lanes(); 
is_const = is_const && (vec[i].flag == kConstant); } - Value ret = NewValue(GetSType(etype.with_lanes(lanes)), kNormal); + Value ret = NewValue(GetSType(etype.WithLanes(lanes)), kNormal); if (is_const && vec.size() == static_cast(lanes)) { ib_.Begin(spv::OpConstantComposite); ib_.AddSeq(ret.stype, ret); @@ -704,53 +704,56 @@ Value IRBuilder::Concat(const std::vector& vec) { Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { TVM_FFI_ICHECK_NE(value.stype.id, 0U); if (value.stype.id == dst_type.id) return value; - const tvm::DataType& from = value.stype.type; - const tvm::DataType& to = dst_type.type; + const tvm::PrimType& from = value.stype.type; + const tvm::PrimType& to = dst_type.type; TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes()); - if (from == DataType::Bool()) { - if (to.is_int()) { + if (from == PrimType::Bool()) { + if (to.MatchesCode(DLDataTypeCode::kDLInt)) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); - } else if (to.is_uint()) { + } else if (to.MatchesCode(DLDataTypeCode::kDLUInt)) { return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0)); - } else if (to.is_float()) { + } else if (to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertUToF, dst_type, Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0))); } else { TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } - } else if (to == DataType::Bool()) { - if (from.is_int()) { + } else if (to == PrimType::Bool()) { + if (from.MatchesCode(DLDataTypeCode::kDLInt)) { return NE(value, IntImm(value.stype, 0)); - } else if (to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt)) { return NE(value, UIntImm(value.stype, 0)); } else { TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } - } else if (from.is_int() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLInt)) { return 
MakeValue(spv::OpSConvert, dst_type, value); - } else if (from.is_uint() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && to.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpUConvert, dst_type, value); - } else if (from.is_uint() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && to.MatchesCode(DLDataTypeCode::kDLInt)) { if (from.bits() != to.bits()) { - value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpUConvert, GetSType(from.WithBits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); - } else if (from.is_int() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLUInt)) { if (from.bits() != to.bits()) { - value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpSConvert, GetSType(from.WithBits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); - } else if (from.is_float() && to.is_int()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && to.MatchesCode(DLDataTypeCode::kDLInt)) { return MakeValue(spv::OpConvertFToS, dst_type, value); - } else if (from.is_float() && to.is_uint()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && + to.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpConvertFToU, dst_type, value); - } else if (from.is_int() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLInt) && to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertSToF, dst_type, value); - } else if (from.is_uint() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLUInt) && + to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpConvertUToF, dst_type, value); - } else if (from.is_float() && to.is_float()) { + } else if (from.MatchesCode(DLDataTypeCode::kDLFloat) && + 
to.MatchesCode(DLDataTypeCode::kDLFloat)) { return MakeValue(spv::OpFConvert, dst_type, value); } else { TVM_FFI_THROW(InternalError) << "do not support type cast from " << from << " to " << to; @@ -782,28 +785,28 @@ Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { return ret; } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, a.stype, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF##_Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpI##_Op, a.stype, a, b); \ + } else { \ + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } -#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, a.stype, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, a.stype, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF##_Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { \ + return MakeValue(spv::OpS##_Op, a.stype, a, b); \ + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpU##_Op, a.stype, a, b); \ + } else { \ + 
TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } DEFINE_BUILDER_BINARY_USIGN_OP(Add, Add); @@ -813,29 +816,29 @@ DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div); Value IRBuilder::Mod(Value a, Value b) { TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); - if (a.stype.type.is_int()) { + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { return MakeValue(spv::OpSRem, a.stype, a, b); - } else if (a.stype.type.is_uint()) { + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { return MakeValue(spv::OpUMod, a.stype, a, b); } else { - TVM_FFI_ICHECK(a.stype.type.is_float()); + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); return MakeValue(spv::OpFRem, a.stype, a, b); } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(PrimType::Bool().WithLanes(a.stype.type.lanes())); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt)) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.MatchesCode(DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + 
TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -843,17 +846,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ - TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ - const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - TVM_FFI_ICHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + const auto& bool_type = this->GetSType(PrimType::Bool().WithLanes(a.stype.type.lanes())); \ + if (a.stype.type.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt)) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + TVM_FFI_ICHECK(a.stype.type.MatchesCode(DLDataTypeCode::kDLFloat)); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); @@ -861,7 +864,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); - TVM_FFI_ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); + TVM_FFI_ICHECK_EQ(cond.stype.type.WithLanes(1), PrimType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/src/backend/vulkan/codegen/ir_builder.h b/src/backend/vulkan/codegen/ir_builder.h index 3cca1b4cfe33..85dbdc00cff4 100644 --- 
a/src/backend/vulkan/codegen/ir_builder.h +++ b/src/backend/vulkan/codegen/ir_builder.h @@ -50,7 +50,7 @@ struct SType { /*! \brief The Id to represent type */ uint32_t id{0}; /*! \brief corresponding TVM type */ - tvm::DataType type; + tvm::PrimType type; /*! \brief content type id if it is a pointer/struct-array class */ uint32_t element_type_id{0}; /*! \brief The storage class, if it is a pointer */ @@ -430,7 +430,7 @@ class IRBuilder { * \return The result value. */ Value CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, - const DataType& dtype); + const PrimType& dtype); /*! * \brief Build vector by concatenating components @@ -444,7 +444,7 @@ class IRBuilder { * \param dtype The data type. * \return The corresponding spirv type. */ - SType GetSType(const tvm::DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType GetSType(const tvm::PrimType& dtype, uint32_t row = 0, uint32_t col = 0); /*! * \brief Get the pointer type that points to value_type * \param value_type. @@ -656,11 +656,11 @@ class IRBuilder { Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type - SType DeclareType(const DataType& dtype, uint32_t row = 0, uint32_t col = 0); + SType DeclareType(const PrimType& dtype, uint32_t row = 0, uint32_t col = 0); // Declare the appropriate SPIR-V capabilities and extensions to use // this data type. - void AddCapabilityFor(const DataType& dtype); + void AddCapabilityFor(const PrimType& dtype); /*! 
\brief SPIRV-related capabilities of the target * diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc index 440f1f04b95e..7129aa23d2ee 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -68,7 +68,7 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { StmtExprVisitor::VisitExpr_(op); Var buffer_var = ffi::GetRef(op); - if (buffer_var.dtype().is_handle()) { + if (buffer_var.ty().IsHandle()) { info_.write_access_set.insert(buffer_var); } } @@ -119,7 +119,7 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->params) { - if (arg.dtype().is_handle()) { + if (arg.ty().IsHandle()) { alloc_storage_scope_[arg.get()] = "global"; } } @@ -174,10 +174,10 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re os_param_access << "paramWriteAccess:["; // setup buffer argumemts for (Var arg : f->params) { - DataType t = arg.dtype(); - func_arg_types.push_back(t); + PrimType t = arg.ty(); + func_arg_types.push_back(t->dtype); - if (t.is_handle()) { + if (t.IsHandle()) { auto* ptr = arg->type_annotation.as(); TVM_FFI_ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " @@ -188,11 +188,11 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " "PointerType, " << "and must point to a PrimType"; - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::Bool()) { + PrimType value_storage_type(prim->dtype); + if (value_storage_type.MatchesCode(DLDataTypeCode::kDLBool)) { // We need a physically addressable buffer type to support boolean tensors. // The loaded byte is cast to bool inside the LoadNode visitor below. 
- value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); + value_storage_type = boolean_storage_type_.WithLanes(value_storage_type.lanes()); } std::string vid = AllocVarID(arg.get()); std::string access_mode; @@ -209,7 +209,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re // add extra access mode info to launch params this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " << "var " << vid << " : array<"; - this->PrintType(value_storage_type, this->decl_stream); + this->PrintType(value_storage_type->dtype, this->decl_stream); this->decl_stream << ">;\n"; } else { pod_args.push_back(arg); @@ -228,17 +228,17 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re for (size_t i = 0; i < pod_args.size(); ++i) { Var v = pod_args[i]; - TVM_FFI_ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.ty().IsHandle()); std::string vid = AllocVarID(v.get()); - if (v.dtype() == DataType::Int(32)) { + if (v.ty() == PrimType::Int(32)) { this->decl_stream << " " << vid << ": i32"; - } else if (v.dtype() == DataType::UInt(32)) { + } else if (v.ty() == PrimType::UInt(32)) { this->decl_stream << " " << vid << ": u32"; - } else if (v.dtype() == DataType::Float(32)) { + } else if (v.ty() == PrimType::Float(32)) { this->decl_stream << " " << vid << ": f32"; } else { - TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.dtype(); + TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.ty()->dtype; } this->decl_stream << ",\n"; // value ref @@ -289,13 +289,13 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); std::ostringstream os; - PrintType(iv->var.dtype(), os); + PrintType(iv->var.ty()->dtype, os); if (iv->thread_tag == "blockIdx.x") { // WebGPU have restriction to limit the maximum size of blockId.x to be 
65535 // We allow runtime to spread the load out to blockIdx.z so it can be a large number. os << "(blockIdx.z * gridDim.x + blockIdx.x)"; std::string tidx = os.str(); - std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype()); + std::string aggregated_bidx = SSAGetID(os.str(), iv->var.ty()->dtype); var_idmap_[iv->var.get()] = aggregated_bidx; } else { os << "(" << iv->thread_tag << ")"; @@ -305,16 +305,17 @@ void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { } } -void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) +void CodeGenWebGPU::PrintType(DLDataType raw_t, std::ostream& os) { // NOLINT(*) + PrimType t(raw_t); int lanes = t.lanes(); - if (t.is_handle()) { + if (t.IsHandle()) { TVM_FFI_THROW(InternalError) << "Cannot print handle type in WebGPU"; } - if (t.is_void()) { + if (t.IsVoid()) { os << "void"; return; } - if (t == DataType::Bool()) { + if (raw_t == DLDataType{kDLBool, 8, 1}) { os << "bool"; return; } @@ -323,28 +324,29 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) TVM_FFI_ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; // Currently WebGPU doesn't support `i8` and an `int8x4` is represented as a `u32`. 
- if (t.is_int() && t.bits() == 8 && lanes == 4) { + if (t.MatchesCode(DLDataTypeCode::kDLInt) && t.bits() == 8 && lanes == 4) { os << "u32"; return; } os << "vec" << lanes << "<"; } - if (t.is_float()) { + if (t.code() == DLDataTypeCode::kDLFloat) { TVM_FFI_ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; if (t.bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; } os << "f" << t.bits(); - } else if (t.is_uint()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLUInt)) { TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64"; os << "u" << t.bits(); - } else if (t.is_int()) { + } else if (t.MatchesCode(DLDataTypeCode::kDLInt)) { TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64"; os << "i" << t.bits(); } else { - TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; + TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " + << ffi::DLDataTypeToString(raw_t) << " to WebGPU type"; } if (lanes != 1) { os << ">"; @@ -365,18 +367,18 @@ void CodeGenWebGPU::PrintStorageSync(const CallNode* op) { } void CodeGenWebGPU::PrintSSAAssign(const std::string& target, const std::string& src, - DataType type) { + PrimType type) { stream << "let " << target << " : "; - PrintType(type, stream); + PrintType(type->dtype, stream); stream << " = " << src << ";\n"; } -void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DataType t, int i, +void CodeGenWebGPU::PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DataType t, int i, +void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, DLDataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "] = " << value << ";\n"; @@ -384,8 +386,8 @@ void CodeGenWebGPU::PrintVecElemStore(const std::string& vec, 
DataType t, int i, void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); - PrintType(op->dtype, os); + int lanes = op->ty().lanes(); + PrintType(op->ty()->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -395,14 +397,14 @@ void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // } PrimExpr CodeGenWebGPU::EnforceU32(PrimExpr value) { - return cast(DataType::UInt(32, value.dtype().lanes()), value); + return cast(PrimType::UInt(32, value.ty().lanes()), value); } void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->op.same_as(builtin::reinterpret())) { // generate bitcast(ARG) os << "bitcast<"; - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << ")"; @@ -426,7 +428,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN std::string cond = PrintExpr(op->args[0]); this->PrintIndent(); this->stream << "var " << result << " : "; - PrintType(op->dtype, this->stream); + PrintType(op->ty()->dtype, this->stream); this->stream << ";\n"; this->PrintIndent(); this->stream << "if (" << cond << ") {\n"; @@ -459,7 +461,7 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN } void CodeGenWebGPU::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) - PrintType(op->dtype, os); + PrintType(op->ty()->dtype, os); os << "(" << PrintExpr(op->value) << ")"; } @@ -478,7 +480,7 @@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT PrintIndent(); std::string value = PrintExpr(op->value); this->stream << "let " << AllocVarID(op->var.get()) << " : "; - PrintType(op->var.dtype(), this->stream); + PrintType(op->var.ty()->dtype, this->stream); this->stream << " = " << value << ";\n"; } os << PrintExpr(op->body); @@ -490,18 +492,18 
@@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT } void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) - if (op->dtype.bits() == 32) { + if (op->ty().bits() == 32) { std::ostringstream temp; - if (op->dtype.is_int()) { + if (op->ty().MatchesCode(DLDataTypeCode::kDLInt)) { temp << op->value << "i"; } else { - TVM_FFI_ICHECK(op->dtype.is_uint()); + TVM_FFI_ICHECK(op->ty().MatchesCode(DLDataTypeCode::kDLUInt)); temp << op->value << "u"; } this->MarkConst(temp.str()); os << temp.str(); } else { - this->PrintType(op->dtype, os); + this->PrintType(op->ty()->dtype, os); os << "(" << op->value << ")"; } } @@ -509,14 +511,14 @@ void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) std::ostringstream temp; temp << std::scientific << op->value; - if (op->dtype.bits() == 32) { + if (op->ty().bits() == 32) { temp << 'f'; - } else if (op->dtype.bits() == 16) { + } else if (op->ty().bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; temp << 'h'; } else { - TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->dtype.bits(); + TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->ty().bits(); } MarkConst(temp.str()); os << temp.str(); @@ -530,39 +532,42 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; - DataType value_dtype = op->dtype; + DLDataType value_dtype = op->ty()->dtype; + PrimType value_ty(value_dtype); PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; - DataType element_dtype = op->buffer->dtype; + DLDataType element_dtype = op->buffer->dtype->dtype; + PrimType element_ty(element_dtype); - int lanes = 
op->dtype.lanes(); + int lanes = value_ty.lanes(); std::string buffer_vid = GetVarID(buffer_var.get()); - if (value_dtype.lanes() == element_dtype.lanes()) { + if (value_ty.lanes() == element_ty.lanes()) { // Direct buffer loading // Special handle bool loading - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { this->PrintType(value_dtype, os); os << "("; } else { TVM_FFI_ICHECK(value_dtype == element_dtype); } - TVM_FFI_ICHECK_EQ(index.dtype().lanes(), 1); + TVM_FFI_ICHECK_EQ(index.ty().lanes(), 1); os << buffer_vid << "[" << this->PrintExpr(index) << "]"; // Special handle bool loading - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { os << ")"; } } else { // Vector load from scalar buffer - TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_ty.lanes(), 1) << "Can only vector load scalar array"; + DLDataType value_element_dtype{value_dtype.code, value_dtype.bits, 1}; + TVM_FFI_ICHECK(value_element_dtype == element_dtype) << "WebGPU vector loading requires base type to match"; arith::PVar base; - if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { + if (arith::ramp(base, 1, value_ty.lanes()).Match(index)) { // vec3(buf[base + 0], buf[base + 1], buf[base + 2]); - std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().ty()->dtype); + PrintType(element_ty.WithLanes(value_ty.lanes())->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -571,8 +576,8 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // os << ")"; } else { // vec3(buf[index[0]], buf[index[1]], buf[index[2]]); - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - 
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); + std::string index_vid = SSAGetID(PrintExpr(index), index.ty()->dtype); + PrintType(element_ty.WithLanes(value_ty.lanes())->dtype, os); os << "("; for (int i = 0; i < lanes; ++i) { if (i != 0) os << ", "; @@ -593,7 +598,7 @@ void CodeGenWebGPU::VisitStmt_(const BindNode* op) { PrintIndent(); std::string value = PrintExpr(op->value); this->stream << "let " << AllocVarID(op->var.get()) << " : "; - PrintType(op->var.dtype(), this->stream); + PrintType(op->var.ty()->dtype, this->stream); this->stream << " = " << value << ";\n"; } } @@ -602,14 +607,16 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; - DataType value_dtype = op->value.dtype(); - DataType element_dtype = op->buffer->dtype; + DLDataType value_dtype = op->value.ty()->dtype; + PrimType value_ty(value_dtype); + DLDataType element_dtype = op->buffer->dtype->dtype; + PrimType element_ty(element_dtype); PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; std::string buffer_vid = GetVarID(buffer_var.get()); - if (value_dtype.lanes() == element_dtype.lanes()) { + if (value_ty.lanes() == element_ty.lanes()) { // must execute print expr first // so we won't have recursive append to stream std::string index_vid = PrintExpr(index); @@ -618,7 +625,7 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { this->PrintIndent(); stream << buffer_vid << "[" << index_vid << "] = "; // special explicit conversion of bool - if (value_dtype == DataType::Bool()) { + if (value_dtype == DLDataType{kDLBool, 8, 1}) { PrintType(element_dtype, stream); stream << "("; } else { @@ -626,22 +633,23 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { } stream << value_vid; // Special handle bool store - if (value_dtype == DataType::Bool()) { + if 
(value_dtype == DLDataType{kDLBool, 8, 1}) { stream << ")"; } stream << ";\n"; } else { // Vector store into scalar buffer - TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_ty.lanes(), 1) << "Can only vector load scalar array"; + DLDataType value_element_dtype{value_dtype.code, value_dtype.bits, 1}; + TVM_FFI_ICHECK(value_element_dtype == element_dtype) << "WebGPU vector stire requires base type to match"; std::string value_vid = PrintExpr(op->value); arith::PVar base; - if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) { + if (arith::ramp(base, 1, value_ty.lanes()).Match(index)) { // buf[base + 0] = value[0] // buf[base + 1] = value[1] - std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { + std::string base_vid = SSAGetID(PrintExpr(base.Eval()), base.Eval().ty()->dtype); + for (int i = 0; i < value_ty.lanes(); ++i) { this->PrintIndent(); stream << buffer_vid << "[" << base_vid << " + " << i << "] = " << value_vid << "[" << i << "];\n"; @@ -649,8 +657,8 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { } else { // buf[index[0]] = value[0] // buf[index[1]] = value[1] - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { + std::string index_vid = SSAGetID(PrintExpr(index), index.ty()->dtype); + for (int i = 0; i < value_ty.lanes(); ++i) { this->PrintIndent(); stream << buffer_vid << "[" << index_vid << "[" << i << "]] = " << value_vid << "[" << i << "];\n"; @@ -673,12 +681,12 @@ void CodeGenWebGPU::VisitStmt_(const AllocBufferNode* op) { if (storage_scope.rank == runtime::StorageRank::kShared) { this->decl_stream << "var " << vid << " : array<"; - PrintType(op->buffer->dtype, this->decl_stream); + PrintType(op->buffer->dtype->dtype, this->decl_stream); this->decl_stream << ", " << 
constant_size << ">;\n"; } else if (storage_scope.rank == runtime::StorageRank::kLocal) { this->PrintIndent(); this->stream << "var " << vid << " : array<"; - PrintType(op->buffer->dtype, this->stream); + PrintType(op->buffer->dtype->dtype, this->stream); this->stream << ", " << constant_size << ">;\n"; } else { TVM_FFI_THROW(InternalError) << "WebGPU: Do not support storage scope: " @@ -694,7 +702,7 @@ void CodeGenWebGPU::VisitStmt_(const ForNode* op) { std::string vid = AllocVarID(op->loop_var.get()); PrintIndent(); stream << "for (var " << vid << " : "; - PrintType(op->loop_var.dtype(), stream); + PrintType(op->loop_var.ty()->dtype, stream); stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << vid; if (step_str.empty()) { stream << "++"; diff --git a/src/backend/webgpu/codegen/codegen_webgpu.h b/src/backend/webgpu/codegen/codegen_webgpu.h index 4c873ac3db18..c2179c5c48aa 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.h +++ b/src/backend/webgpu/codegen/codegen_webgpu.h @@ -51,16 +51,17 @@ class CodeGenWebGPU final : public CodeGenC { using CodeGenC::AddFunction; runtime::FunctionInfo AddFunction(const PrimFunc& f, bool skip_readonly_decl); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DLDataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // assignment printing - void PrintSSAAssign(const std::string& target, const std::string& src, DataType type) final; + void PrintSSAAssign(const std::string& target, const std::string& src, PrimType type) final; // overload printing vector element load/store - void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; - void 
PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemLoad(const std::string& vec, DLDataType t, int i, std::ostream& os) final; + void PrintVecElemStore(const std::string& vec, DLDataType t, int i, + const std::string& value) final; // overload visitor void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) @@ -90,7 +91,7 @@ class CodeGenWebGPU final : public CodeGenC { /*! * \brief Storage type of bool values. */ - DataType boolean_storage_type_{DataType::Int(8)}; + PrimType boolean_storage_type_{PrimType::Int(8)}; // whether enable fp16 bool enable_fp16_{false}; diff --git a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc index 1c172fcd141b..7992fa9915c0 100644 --- a/src/backend/webgpu/codegen/intrin_rule_webgpu.cc +++ b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc @@ -34,7 +34,7 @@ using tirx::FLowerIntrinsic; // warp-level primitives. Follows implementation in intrin_rule_metal.cc struct WebGPUWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { + const Op operator()(PrimType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { static const Op& webgpu_subgroup_shuffle_op = Op::Get("tirx.webgpu.subgroup_shuffle"); return webgpu_subgroup_shuffle_op; @@ -55,9 +55,9 @@ static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - PrimExpr lane_or_delta = Cast(DataType::UInt(32, call->args[2].dtype().lanes()), call->args[2]); + PrimExpr lane_or_delta = Cast(PrimType::UInt(32, call->args[2].ty().lanes()), call->args[2]); ffi::Array webgpu_args{{call->args[1], lane_or_delta}}; - return Call(call->dtype, T()(call->dtype, call->op.as_or_throw()), webgpu_args); + return Call(e.ty(), T()(e.ty(), call->op.as_or_throw()), webgpu_args); } void 
RegisterWebGPUIntrinRules() { @@ -69,7 +69,7 @@ void RegisterWebGPUIntrinRules() { // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions struct ReturnAbs { - std::string operator()(DataType t, std::string name) const { return "abs"; } + std::string operator()(PrimType t, std::string name) const { return "abs"; } }; TVM_REGISTER_OP("tirx.fabs") @@ -124,7 +124,7 @@ TVM_REGISTER_OP("tirx.pow") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); struct ReturnRound { - std::string operator()(DataType t, std::string name) const { return "round"; } + std::string operator()(PrimType t, std::string name) const { return "round"; } }; // WGSL round() uses ties-to-even (banker's rounding), matching IEEE 754 and ONNX Round spec. diff --git a/src/ir/expr.cc b/src/ir/expr.cc index ef6ea0ed6dca..f73cd6ae3913 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -48,33 +49,39 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm::Int32(value)) {} -PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(PrimType::Float(32), value)) {} PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tirx::StringImm(value); } -IntImm::IntImm(DataType dtype, int64_t value, Span span) { - TVM_FFI_CHECK(dtype.is_scalar(), ValueError) - << "IntImm can only take scalar, but " << dtype << " was supplied."; - TVM_FFI_CHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool(), ValueError) - << "IntImm supports only int or uint or bool type, but " << dtype << " was supplied."; - if (dtype.is_uint()) { +IntImm::IntImm(PrimType value_ty, int64_t value, Span span) { + DLDataType runtime_dtype = value_ty->dtype; + DLDataTypeCode code = value_ty.code(); + int32_t bits = value_ty.bits(); + TVM_FFI_CHECK(!value_ty.IsScalableVector() && !value_ty.IsFixedLengthVector(), ValueError) + << "IntImm 
can only take scalar, but " << runtime_dtype << " was supplied."; + TVM_FFI_CHECK(value_ty.MatchesCode(DLDataTypeCode::kDLInt, DLDataTypeCode::kDLUInt, + DLDataTypeCode::kDLBool), + ValueError) + << "IntImm supports only int or uint or bool type, but " << runtime_dtype << " was supplied."; + if (code == DLDataTypeCode::kDLUInt) { TVM_FFI_CHECK_GE(value, 0U, ValueError) - << "Literal value " << value << " is negative for unsigned integer type " << dtype; - if (dtype.bits() < 64) { - TVM_FFI_CHECK_LT(value, 1LL << dtype.bits(), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " is negative for unsigned integer type " << runtime_dtype; + if (bits < 64) { + TVM_FFI_CHECK_LT(value, 1LL << bits, ValueError) + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } - } else if (dtype.bits() == 1 || dtype.is_bool()) { + } else if (bits == 1 || code == DLDataTypeCode::kDLBool) { // int(1) - TVM_FFI_CHECK(value == 0 || value == 1, ValueError) << value << " exceeds range of " << dtype; - } else if (dtype.bits() < 64) { - TVM_FFI_CHECK_GE(value, -(1LL << (dtype.bits() - 1)), ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; - TVM_FFI_CHECK_LT(value, 1LL << (dtype.bits() - 1), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK(value == 0 || value == 1, ValueError) + << value << " exceeds range of " << runtime_dtype; + } else if (bits < 64) { + TVM_FFI_CHECK_GE(value, -(1LL << (bits - 1)), ValueError) + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; + TVM_FFI_CHECK_LT(value, 1LL << (bits - 1), ValueError) + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } ffi::ObjectPtr node = ffi::make_object(); - node->dtype = dtype; + node->BaseExprNode::ty = std::move(value_ty); node->value = value; node->span = span; data_ = std::move(node); @@ -82,103 +89,118 @@ IntImm::IntImm(DataType 
dtype, int64_t value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.IntImm", [](DataType dtype, int64_t value, Span span) { - return IntImm(dtype, value, span); + refl::GlobalDef().def("ir.IntImm", [](DLDataType dtype, int64_t value, Span span) { + return IntImm(PrimType(dtype), value, span); }); } -FloatImm::FloatImm(DataType dtype, double value, Span span) { - TVM_FFI_CHECK_EQ(dtype.lanes(), 1, ValueError) << "FloatImm can only take scalar."; +FloatImm::FloatImm(PrimType value_ty, double value, Span span) { + DLDataType runtime_dtype = value_ty->dtype; + DLDataTypeCode code = value_ty.code(); + int32_t bits = value_ty.bits(); + TVM_FFI_CHECK(!value_ty.IsScalableVector() && !value_ty.IsFixedLengthVector(), ValueError) + << "FloatImm can only take scalar."; - TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4() || dtype.code() >= DataType::kCustomBegin, - ValueError) - << "FloatImm supports only float, but " << dtype << " was supplied."; + TVM_FFI_CHECK( + value_ty.MatchesCode(DLDataTypeCode::kDLFloat, DLDataTypeCode::kDLFloat8_e3m4, + DLDataTypeCode::kDLFloat8_e4m3, DLDataTypeCode::kDLFloat8_e4m3b11fnuz, + DLDataTypeCode::kDLFloat8_e4m3fn, DLDataTypeCode::kDLFloat8_e4m3fnuz, + DLDataTypeCode::kDLFloat8_e5m2, DLDataTypeCode::kDLFloat8_e5m2fnuz, + DLDataTypeCode::kDLFloat8_e8m0fnu, DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn) || + value_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16) || + value_ty.MatchesElementType(DLDataTypeCode::kDLFloat4_e2m1fn, 4) || + static_cast(code) >= static_cast(ffi::DLExtDataTypeCode::kDLExtCustomBegin), + ValueError) + << "FloatImm supports only float, but " << runtime_dtype << " was supplied."; // check range for float32 and float16 since they have specified range. 
if (!std::isinf(value) && !std::isnan(value)) { - if (dtype.bits() == 32) { + if (bits == 32) { TVM_FFI_CHECK_GE(value, std::numeric_limits::lowest(), ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, std::numeric_limits::max(), ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float16()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesElementType(DLDataTypeCode::kDLFloat, 16)) { TVM_FFI_CHECK_GE(value, -support::kMaxFloat16, ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, support::kMaxFloat16, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_bfloat16()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesElementType(DLDataTypeCode::kDLBfloat, 16)) { TVM_FFI_CHECK_GE(value, -support::kMaxBFloat16, ValueError) - << "Literal value " << value << " exceeds minimum of " << dtype; + << "Literal value " << value << " exceeds minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, support::kMaxBFloat16, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; - } else if (dtype.is_float8_e3m4() || dtype.is_float8_e4m3() || dtype.is_float8_e4m3b11fnuz() || - dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || dtype.is_float8_e5m2() || - dtype.is_float8_e5m2fnuz() || dtype.is_float8_e8m0fnu()) { + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; + } else if (value_ty.MatchesCode( + DLDataTypeCode::kDLFloat8_e3m4, DLDataTypeCode::kDLFloat8_e4m3, + DLDataTypeCode::kDLFloat8_e4m3b11fnuz, DLDataTypeCode::kDLFloat8_e4m3fn, + DLDataTypeCode::kDLFloat8_e4m3fnuz, 
DLDataTypeCode::kDLFloat8_e5m2, + DLDataTypeCode::kDLFloat8_e5m2fnuz, DLDataTypeCode::kDLFloat8_e8m0fnu)) { double bound = 0.0; bool nonneg = false; - switch (dtype.code()) { - case DataType::TypeCode::kFloat8_e3m4: + switch (code) { + case DLDataTypeCode::kDLFloat8_e3m4: bound = support::kMaxE3M4; break; - case DataType::TypeCode::kFloat8_e4m3: + case DLDataTypeCode::kDLFloat8_e4m3: bound = support::kMaxE4M3; break; - case DataType::TypeCode::kFloat8_e4m3b11fnuz: + case DLDataTypeCode::kDLFloat8_e4m3b11fnuz: bound = support::kMaxE4M3B11FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e4m3fn: + case DLDataTypeCode::kDLFloat8_e4m3fn: bound = support::kMaxE4M3FN; break; - case DataType::TypeCode::kFloat8_e4m3fnuz: + case DLDataTypeCode::kDLFloat8_e4m3fnuz: bound = support::kMaxE4M3FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e5m2: + case DLDataTypeCode::kDLFloat8_e5m2: bound = support::kMaxE5M2; break; - case DataType::TypeCode::kFloat8_e5m2fnuz: + case DLDataTypeCode::kDLFloat8_e5m2fnuz: bound = support::kMaxE5M2FNUZ; nonneg = true; break; - case DataType::TypeCode::kFloat8_e8m0fnu: + case DLDataTypeCode::kDLFloat8_e8m0fnu: bound = support::kMaxE8M0FNU; nonneg = true; break; default: - TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << dtype; + TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << runtime_dtype; } if (nonneg) { TVM_FFI_CHECK_GE(value, 0, ValueError) - << "Literal value " << value << " below zero for unsigned " << dtype; + << "Literal value " << value << " below zero for unsigned " << runtime_dtype; } else { TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; } TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; - } else if 
(dtype.is_float6_e2m3fn() || dtype.is_float6_e3m2fn()) { - double bound = (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) ? support::kMaxE2M3FN - : support::kMaxE3M2FN; + } else if (value_ty.MatchesCode(DLDataTypeCode::kDLFloat6_e2m3fn, + DLDataTypeCode::kDLFloat6_e3m2fn)) { + double bound = + (code == DLDataTypeCode::kDLFloat6_e2m3fn) ? support::kMaxE2M3FN : support::kMaxE3M2FN; TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; - } else if (dtype.is_float4_e2m1fn()) { + } else if (code == DLDataTypeCode::kDLFloat4_e2m1fn) { double bound = support::kMaxE2M1FN; TVM_FFI_CHECK_GE(value, -bound, ValueError) - << "Literal value " << value << " below minimum of " << dtype; + << "Literal value " << value << " below minimum of " << runtime_dtype; TVM_FFI_CHECK_LE(value, bound, ValueError) - << "Literal value " << value << " exceeds maximum of " << dtype; + << "Literal value " << value << " exceeds maximum of " << runtime_dtype; } } ffi::ObjectPtr node = ffi::make_object(); - node->dtype = dtype; + node->BaseExprNode::ty = std::move(value_ty); node->value = value; node->span = span; data_ = std::move(node); @@ -186,8 +208,8 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.FloatImm", [](DataType dtype, double value, Span span) { - return FloatImm(dtype, value, span); + refl::GlobalDef().def("ir.FloatImm", [](DLDataType dtype, double value, Span span) { + return FloatImm(PrimType(dtype), value, span); }); } @@ -206,7 +228,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (end.defined()) { return Range(begin, end.value(), span); } else { - return 
Range(IntImm(begin->dtype, 0), begin, span); + return Range(IntImm(begin.ty(), 0), begin, span); } }); } diff --git a/src/ir/type.cc b/src/ir/type.cc index d6d059dba079..2464f6faa659 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -24,27 +24,120 @@ #include #include #include + +#include +#include + namespace tvm { +namespace { + +DLDataType ScalableVectorDType(DLDataTypeCode code, int bits, int lanes) { + TVM_FFI_ICHECK_GT(lanes, 1) << "Invalid value for vscale factor " << lanes; + TVM_FFI_ICHECK_LT(lanes, 32768); + return DLDataType{static_cast(code), static_cast(bits), + static_cast(-lanes)}; +} + +uint32_t PackDataTypeKey(DLDataType dtype) { + return (static_cast(dtype.code) << 24) | (static_cast(dtype.bits) << 16) | + static_cast(dtype.lanes); +} + +int64_t PrimTypeAnyHash(const ffi::Any& src) { + return static_cast(PackDataTypeKey(src.cast()->dtype)); +} + +bool PrimTypeAnyEqual(const ffi::Any& lhs, const ffi::Any& rhs) { + return lhs.cast()->dtype == rhs.cast()->dtype; +} + +ffi::ObjectPtr GetCachedPrimTypeNode(DLDataType dtype) { + thread_local std::unordered_map> cache; + uint32_t key = PackDataTypeKey(dtype); + auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + ffi::ObjectPtr node = ffi::make_object(); + node->dtype = dtype; + return cache.emplace(key, std::move(node)).first->second; +} + +} // namespace + TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; TypeNode::RegisterReflection(); PrimTypeNode::RegisterReflection(); + refl::TypeAttrDef() + .attr(refl::type_attr::kAnyHash, reinterpret_cast(&PrimTypeAnyHash)) + .attr(refl::type_attr::kAnyEqual, reinterpret_cast(&PrimTypeAnyEqual)); PointerTypeNode::RegisterReflection(); TupleTypeNode::RegisterReflection(); FuncTypeNode::RegisterReflection(); TensorMapTypeNode::RegisterReflection(); } -PrimType::PrimType(runtime::DataType dtype, Span span) { - ffi::ObjectPtr n = ffi::make_object(); - n->dtype = dtype; - n->span = std::move(span); - data_ = 
std::move(n); +PrimType::PrimType(DLDataType dtype) { data_ = GetCachedPrimTypeNode(dtype); } + +PrimType::PrimType(DLDataTypeCode code, int bits, int lanes) + : PrimType(DLDataType{static_cast(code), static_cast(bits), + static_cast(lanes)}) {} + +PrimType PrimType::Int(int bits, int lanes) { + if (lanes == 1) { + if (bits == 32) { + static const PrimType i32_ty(DLDataType{kDLInt, 32, 1}); + return i32_ty; + } + if (bits == 64) { + static const PrimType i64_ty(DLDataType{kDLInt, 64, 1}); + return i64_ty; + } + } + return PrimType(DLDataType{kDLInt, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::UInt(int bits, int lanes) { + return PrimType(DLDataType{kDLUInt, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Float(int bits, int lanes) { + if (bits == 32 && lanes == 1) { + static const PrimType f32_ty(DLDataType{kDLFloat, 32, 1}); + return f32_ty; + } + return PrimType(DLDataType{kDLFloat, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::BFloat(int bits, int lanes) { + return PrimType(DLDataType{kDLBfloat, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Bool(int lanes) { + if (lanes == 1) { + static const PrimType bool_ty(DLDataType{kDLBool, 8, 1}); + return bool_ty; + } + return PrimType(DLDataType{kDLBool, 8, static_cast(lanes)}); +} + +PrimType PrimType::Handle(int bits, int lanes) { + return PrimType( + DLDataType{kDLOpaqueHandle, static_cast(bits), static_cast(lanes)}); +} + +PrimType PrimType::Void() { return PrimType(DLDataType{kDLOpaqueHandle, 0, 0}); } + +PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) { + return PrimType(ScalableVectorDType(code, bits, lanes)); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ir.PrimType", [](runtime::DataType dtype) { return PrimType(dtype); }); + refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); }); } PointerType::PointerType(Type element_type, 
ffi::String storage_scope) { diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 369f5793d9b5..0bfb48cca94c 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -478,8 +478,8 @@ bool HasReshapePattern(const PrimFunc& func) { } if (nontrivial_indices.defined()) { - DataType dtype = - !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64); + PrimType dtype = + !block->iter_vars.empty() ? block->iter_vars[0]->var.ty() : PrimType::Int(64); tirx::Var fused_var("fused", dtype); ffi::Map inverse_indices_map; PrimExpr stride = IntImm(dtype, /*value=*/1); @@ -494,7 +494,8 @@ bool HasReshapePattern(const PrimFunc& func) { ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, - /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, + /*input_iters=*/ + ffi::Map{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, /*input_pred=*/IntImm::Bool(true), /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/this->ana_, diff --git a/src/relax/analysis/type_analysis.cc b/src/relax/analysis/type_analysis.cc index 33070051ae63..34f5a4de6216 100644 --- a/src/relax/analysis/type_analysis.cc +++ b/src/relax/analysis/type_analysis.cc @@ -43,7 +43,7 @@ class StaticTypeDeriver : public TypeFunctor { public: Type VisitType_(const ObjectTypeNode* op) final { return ObjectType(op->span); } - Type VisitType_(const PrimTypeNode* op) final { return PrimType(op->dtype, op->span); } + Type VisitType_(const PrimTypeNode* op) final { return tvm::PrimType(op->dtype); } Type VisitType_(const ShapeTypeNode* op) final { return ShapeType(op->ndim, op->span); } @@ -86,7 +86,9 @@ Type TypeFromStaticType(const Type& type) { if (type.as()) { return ObjectType(type->span); } else if (const PrimTypeNode* prim_type = type.as()) { - return PrimType(prim_type->dtype, prim_type->span); + return tvm::PrimType(prim_type->dtype); + } 
else if (const tvm::PrimTypeNode* prim_type = type.as()) { + return tvm::PrimType(prim_type->dtype); } else if (const ShapeTypeNode* shape_type = type.as()) { return ShapeType(shape_type->ndim, type->span); } else if (const TensorTypeNode* tensor_type = type.as()) { @@ -221,9 +223,9 @@ class WellDefinedEraser : public TypeMutator, public ExprMutatorBase, public tir if (ret.defined()) { PrimExpr value = ret.value(); if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "Can only provide i64 expressions in shape"; return value; } else { @@ -1015,7 +1017,9 @@ class TypeLCAFinder : public TypeFunctor { if (rhs == nullptr) return ObjectType(lhs->span); // find the target dtype, ndim, and vdevice. - DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void(); + PrimType dtype = lhs->dtype->dtype == rhs->dtype->dtype + ? PrimType(lhs->dtype->dtype) + : PrimType(DLDataType{kDLOpaqueHandle, 0, 0}); int ndim = lhs->ndim == rhs->ndim ? 
lhs->ndim : kUnknownNDim; VDevice vdev = VDevice(); if (lhs->vdevice.defined() && rhs->vdevice.defined() && @@ -1028,7 +1032,7 @@ class TypeLCAFinder : public TypeFunctor { !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), ffi::GetRef(analyzer_))) { // reuse lhs when possible - if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && + if (!lhs->shape.defined() && lhs->dtype->dtype == dtype->dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { return ffi::GetRef(lhs); } else { @@ -1036,7 +1040,7 @@ class TypeLCAFinder : public TypeFunctor { } } // symbolic shape and vdevice match but dtype mismatch - if (lhs->dtype != dtype || (lhs->vdevice.defined() && !vdev.defined())) { + if (lhs->dtype->dtype != dtype->dtype || (lhs->vdevice.defined() && !vdev.defined())) { return TensorType(lhs->shape.value(), dtype, vdev, lhs->span); } else { return ffi::GetRef(lhs); diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 5c3547249c5e..52e974be75f0 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -457,9 +457,9 @@ class WellFormedChecker : public relax::ExprVisitor, for (PrimExpr expr : op->values) { // check if the symbolic vars in the expr are defined, e.g, 2 * m tirx::ExprVisitor::VisitExpr(expr); - if (!expr.dtype().is_int()) { + if (expr.ty().code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, expr) - << "Shape expressions must be of integer type, but got " << expr.dtype(); + << "Shape expressions must be of integer type, but got " << expr.ty()->dtype; } } CheckType(op); diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 1a5fb1dd801e..0c36b04812c8 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -347,19 +347,20 @@ class CodegenCBase { */ std::string GetDtypeString(const TensorTypeNode* tensor_ty) { 
std::string dtype; - if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 32)) { + DLDataType raw_dtype = tensor_ty->dtype->dtype; + if (raw_dtype == DLDataType{kDLFloat, 32, 1}) { dtype = "float"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLFloat, 16)) { + } else if (raw_dtype == DLDataType{kDLFloat, 16, 1}) { dtype = "half"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLBfloat, 16)) { + } else if (raw_dtype == DLDataType{kDLBfloat, 16, 1}) { dtype = "bfloat"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 32)) { + } else if (raw_dtype == DLDataType{kDLInt, 32, 1}) { dtype = "int"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 64)) { + } else if (raw_dtype == DLDataType{kDLInt, 64, 1}) { dtype = "int64_t"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLInt, 8)) { + } else if (raw_dtype == DLDataType{kDLInt, 8, 1}) { dtype = "int8_t"; - } else if (runtime::TypeMatch(tensor_ty->dtype, kDLUInt, 8)) { + } else if (raw_dtype == DLDataType{kDLUInt, 8, 1}) { dtype = "uint8_t"; } else { TVM_FFI_THROW(InternalError) << "Unsupported dtype " << tensor_ty->dtype; diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 5284de94f622..f2999b172136 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -86,11 +86,11 @@ class CublasJSONSerializer : public JSONSerializer { const auto* const_expr = dequantize_call->args[1].as(); auto ty = const_expr->ty.as_or_throw(); float alpha = 1.0; - if (ty->dtype == DataType::Float(16)) { + if (ty->dtype == PrimType::Float(16)) { alpha = __extendXfYf2__( static_cast(const_expr->data->data)[0]); } else { - TVM_FFI_ICHECK(ty->dtype == DataType::Float(32)); + TVM_FFI_ICHECK(ty->dtype == PrimType::Float(32)); alpha = static_cast(const_expr->data->data)[0]; } diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index 93916bf23236..6147a6eb2199 100644 --- 
a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -59,9 +59,7 @@ inline std::vector GetIntShape(const ffi::Array& shape) { * \param typ * \return std::string string format of type */ -inline std::string DType2String(const tvm::DataType dtype) { - return tvm::ffi::DLDataTypeToString(dtype); -} +inline std::string DType2String(DLDataType dtype) { return tvm::ffi::DLDataTypeToString(dtype); } /*! * \brief Check if a call node is calling an op with the given name diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index c1e9af85511c..3e2ac365d4fb 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -88,19 +88,19 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr RegListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {reg_anylist_handle_, ConstInt32(slot)}); } PrimExpr ConstListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {const_anylist_handle_, ConstInt32(slot)}); } PrimExpr FuncListGet(int64_t slot) const { // use 128 bits to represent any - return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::anylist_getitem(), {func_anylist_handle_, ConstInt32(slot)}); } @@ -121,11 +121,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args))); + this->EmitStmt(tirx::Evaluate(tirx::Call( + tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_packed(), 
all_args))); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); + tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); } } @@ -143,11 +143,11 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args))); + this->EmitStmt(tirx::Evaluate(tirx::Call( + tvm::PrimType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args))); } else { this->EmitStmt(tirx::Evaluate( - tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); + tirx::Call(tvm::PrimType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); } } @@ -160,10 +160,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { stmt_stack_ = {}; registers_num_ = 0; var_map_.clear(); - ctx_ptr_ = tirx::Var("ctx_ptr", DataType::Handle()); - reg_anylist_handle_ = tirx::Var("r", DataType::Handle()); - func_anylist_handle_ = tirx::Var("f", DataType::Handle()); - const_anylist_handle_ = tirx::Var("c", DataType::Handle()); + ctx_ptr_ = tirx::Var("ctx_ptr", PrimType::Handle()); + reg_anylist_handle_ = tirx::Var("r", PrimType::Handle()); + func_anylist_handle_ = tirx::Var("f", PrimType::Handle()); + const_anylist_handle_ = tirx::Var("c", PrimType::Handle()); ffi::Array param_names; for (Var param : func->params) { @@ -231,7 +231,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { - return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); + return tirx::Call(tvm::PrimType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); } int64_t dst_reg = HasVoidType(call) ? 
-1 : NewRegister(); if (call->op.as()) { @@ -264,7 +264,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - cond_value = tirx::Call(DataType::Bool(), tirx::builtin::tvm_call_packed(), + cond_value = tirx::Call(tvm::PrimType::Bool(), tirx::builtin::tvm_call_packed(), {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}); tirx::Stmt true_branch = WithNewScope([&]() { @@ -438,7 +438,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); const auto* p_dst_reg = tir_call->args[1].as(); TVM_FFI_ICHECK(p_dst_reg != nullptr); - TVM_FFI_ICHECK(p_dst_reg->dtype == DataType::Int(32)); + TVM_FFI_ICHECK(p_dst_reg->ty().MatchesElementType(DLDataTypeCode::kDLInt, 32)); int64_t dst_reg = p_dst_reg->value; this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg); diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 344fc6a67e65..4a32efd81e5a 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -21,6 +21,7 @@ * \brief Lowers most builtin functions and packed calls. 
*/ #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include namespace tvm { @@ -85,7 +85,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { Expr MakeMemAllocStorage(const Call& call) { PrimValue runtime_device_index = call->args[1].as_or_throw(); StringImm storage_scope = call->args[2].as_or_throw(); - DataTypeImm output_dtype = DataTypeImm(DataType::UInt(8)); + DataTypeImm output_dtype = DataTypeImm((DLDataType{kDLUInt, 8, 1})); return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype, storage_scope}, Attrs()); } diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 3d895349bbc3..6784489c5b32 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -229,7 +229,7 @@ class VMShapeLowerMutator slot_map_.clear(); current_gvar_ = gvar; PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_); - heap_size_ = IntImm(ShapeDType(), static_cast(slot_vec_.size())); + heap_size_ = IntImm(tvm::PrimType(ShapeDType()), static_cast(slot_vec_.size())); VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_); shape_heap_ = shape_heap_binding->var; @@ -298,7 +298,7 @@ class VMShapeLowerMutator //------------------------------------------------------- // PrimExpr slot handling //------------------------------------------------------- - static DataType ShapeDType() { return DataType::Int(64); } + static DLDataType ShapeDType() { return DLDataType{kDLInt, 64, 1}; } /*! \brief populate additional information in the slot. */ void PopulateSlotInfo() { @@ -329,7 +329,7 @@ class VMShapeLowerMutator VarBinding AllocShapeHeapBinding(IntImm heap_size) { if (heap_size->value > 0) { - TensorType heap_ty(ShapeDType(), 1); + TensorType heap_ty(PrimType(ShapeDType()), 1); Var var("shape_heap", heap_ty); // set up the builtin func. 
Call call(call_builtin_with_ctx_op_, @@ -566,7 +566,7 @@ class VMShapeLowerMutator if (to_compute.size() == 0) return 0; TVM_FFI_ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that compute the shape. - tirx::Var heap("heap", DataType::Handle()); + tirx::Var heap("heap", PrimType::Handle()); ffi::Array buffer_shape{heap_size_}; tirx::Buffer buffer = tirx::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); ffi::Map buffer_map; @@ -575,7 +575,8 @@ class VMShapeLowerMutator auto var_map = [&](const tirx::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); TVM_FFI_ICHECK(it != slot_map_.end()); - return tirx::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + return tirx::BufferLoad( + buffer, ffi::Array{IntImm(tvm::PrimType(ShapeDType()), it->second->index)}); }; ffi::Array seq; @@ -583,7 +584,8 @@ class VMShapeLowerMutator TVM_FFI_ICHECK(!slot->value_computed); slot->value_computed = true; PrimExpr value = tirx::Substitute(slot->expr, var_map); - seq.push_back(tirx::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + seq.push_back( + tirx::BufferStore(buffer, value, {IntImm(tvm::PrimType(ShapeDType()), slot->index)})); } tirx::Stmt body = tirx::SeqStmt::Flatten(seq); @@ -678,10 +680,11 @@ class VMShapeLowerMutator // if we only check dynamic shapes, and the shape is static, we can skip. 
return; } - if (always_check || !IsBaseOf(TensorType(op->dtype, op->ndim), GetType(value))) { + if (always_check || !IsBaseOf(TensorType(PrimType(op->dtype), op->ndim), GetType(value))) { // check_tensor_info(value, ndim, dtype, err_ctx) Call call(builtin_check_tensor_info_, - {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype), GetErrContext(err_ctx)}, + {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype->dtype), + GetErrContext(err_ctx)}, Attrs(), {void_ty_}); builder_->Emit(call, "_"); } diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 7b14a1f7e7e9..10fd67de1740 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -736,7 +736,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { return ExternFuncPattern(func->global_symbol); } else if (auto prim = expr.as()) { - return TypePattern(WildcardPattern(), PrimType(prim->value.dtype())); + return TypePattern(WildcardPattern(), PrimType(prim->value.ty())); } else { TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " << expr->GetTypeKey() diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 08689bd10f0b..f75c540a96cd 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -573,8 +573,7 @@ bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr // no need to jump, as var.dtype == value.dtype auto expr_ty = expr.as()->ty; if (const TensorTypeNode* tensor_ty = expr_ty.as()) { - return (ffi::StructuralEqual()(op->dtype, tensor_ty->dtype)) && - VisitDFPattern(op->pattern, expr); + return op->dtype == tensor_ty->dtype->dtype && VisitDFPattern(op->pattern, expr); } return false; } diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 5cb5352ec6c2..6302ee85049a 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ 
-369,15 +369,15 @@ RELAX_PATTERN_PRINTER_DEF(SameShapeConstraintNode, [](auto p, auto node) { p->stream << ")"; }); -DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { +DataTypePattern::DataTypePattern(DFPattern pattern, DLDataType dtype) { ffi::ObjectPtr n = ffi::make_object(); n->pattern = std::move(pattern); - n->dtype = std::move(dtype); + n->dtype = dtype; data_ = std::move(n); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DataType dtype) { + refl::GlobalDef().def("relax.dpl.DataTypePattern", [](DFPattern pattern, DLDataType dtype) { return DataTypePattern(pattern, dtype); }); } @@ -474,11 +474,11 @@ AttrPattern DFPattern::HasAttr(const ffi::Map& attrs) const { return AttrPattern(*this, DictAttrs(attrs)); } TypePattern DFPattern::HasType(const Type& ty) const { return TypePattern(*this, ty); } -DataTypePattern DFPattern::HasDtype(const DataType& dtype) const { +DataTypePattern DFPattern::HasDtype(DLDataType dtype) const { return DataTypePattern(*this, dtype); } DataTypePattern DFPattern::HasDtype(const std::string& dtype) const { - return HasDtype(DataType(ffi::StringToDLDataType(dtype))); + return HasDtype(ffi::StringToDLDataType(dtype)); } ShapePattern DFPattern::HasShape(const ffi::Array& shape) const { return ShapePattern(*this, shape); diff --git a/src/relax/ir/dependent_type.cc b/src/relax/ir/dependent_type.cc index 6a2034ccc2a8..d95ebb1534e7 100644 --- a/src/relax/ir/dependent_type.cc +++ b/src/relax/ir/dependent_type.cc @@ -54,9 +54,9 @@ ShapeType::ShapeType(ffi::Array values, Span span) { n->ndim = static_cast(values.size()); n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "the value in 
ShapeType can only have dtype of int64"; return value; }); @@ -86,7 +86,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Tensor -TensorType::TensorType(Expr shape, DataType dtype, ffi::Optional vdevice, Span span) { +TensorType::TensorType(Expr shape, PrimType dtype, ffi::Optional vdevice, Span span) { ffi::ObjectPtr n = ffi::make_object(); // assign ndim before move TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; @@ -103,7 +103,7 @@ TensorType::TensorType(Expr shape, DataType dtype, ffi::Optional vdevic data_ = std::move(n); } -TensorType::TensorType(DataType dtype, int ndim, ffi::Optional vdevice, Span span) { +TensorType::TensorType(PrimType dtype, int ndim, ffi::Optional vdevice, Span span) { ffi::ObjectPtr n = ffi::make_object(); TVM_FFI_ICHECK(ndim >= -1) << "ndim of TensorType must be >= -1, but got " << ndim; n->ndim = ndim; @@ -116,13 +116,14 @@ TensorType::TensorType(DataType dtype, int ndim, ffi::Optional vdevice, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "relax.TensorType", [](ffi::Optional shape, ffi::Optional dtype, int ndim, + "relax.TensorType", [](ffi::Optional shape, ffi::Optional dtype, int ndim, VDevice vdevice, Span span) { + PrimType resolved_dtype = dtype.value_or(PrimType(DLDataType{kDLOpaqueHandle, 0, 0})); if (shape.defined()) { TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify shape and ndim"; - return TensorType(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); + return TensorType(shape.value(), resolved_dtype, vdevice, span); } else { - return TensorType(dtype.value_or(DataType::Void()), ndim, vdevice, span); + return TensorType(resolved_dtype, ndim, vdevice, span); } }); } diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 304911c1dca2..68e48eaf93b6 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -42,7 +42,7 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std:: // checked-type 
might not be properly set. In this case we set the shape and dtype of the returned // TE tensor. if (const auto* constant = value.as()) { - n->dtype = DataType(constant->data->dtype); + n->dtype = PrimType(constant->data->dtype); int ndim = constant->data->ndim; ffi::Shape shape_tuple = constant->data.Shape(); diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 11e80135500a..b4c4486f0dd4 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -257,9 +257,9 @@ ShapeExpr::ShapeExpr(ffi::Array values, Span span) { n->values = values.Map([](PrimExpr value) { if (value->IsInstance()) { - return tvm::cast(DataType::Int(64), value); + return tvm::cast(PrimType::Int(64), value); } - TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64)) << "the value in ShapeType can only have dtype of int64"; return value; }); @@ -350,7 +350,7 @@ Constant::Constant(runtime::Tensor data, ffi::Optional ty_annotation, Span if (ty_annotation.defined()) { n->ty = ty_annotation.value(); } else { - TensorType tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span); + TensorType tinfo(ShapeExpr(values), PrimType(n->data.DataType()), VDevice(), span); n->ty = tinfo; } @@ -366,7 +366,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimValue::PrimValue(PrimExpr value, Span span) { ffi::ObjectPtr n = ffi::make_object(); - n->ty = PrimType(value.dtype()); + n->ty = PrimType(value.ty()); n->value = std::move(value); n->span = std::move(span); data_ = std::move(n); @@ -396,9 +396,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::String value, Span span) { return StringImm(value, span); }); } -DataTypeImm::DataTypeImm(DataType value, Span span) { +DataTypeImm::DataTypeImm(DLDataType value, Span span) { ffi::ObjectPtr n = ffi::make_object(); - n->value = std::move(value); + n->value = value; n->span = std::move(span); n->ty = ObjectType(); data_ = std::move(n); @@ -407,7 +407,7 @@ DataTypeImm::DataTypeImm(DataType value, Span 
span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.DataTypeImm", - [](DataType value, Span span) { return DataTypeImm(value, span); }); + [](DLDataType value, Span span) { return DataTypeImm(value, span); }); } MatchCast::MatchCast(Var var, Expr value, Type ty, Span span) { diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index dd67f65dea09..15b8064d2b6f 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -85,7 +85,7 @@ Type InferTypeAllGather(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; auto input_shape = input_ty->GetShape(); if (!input_shape.defined()) { return input_ty; @@ -143,7 +143,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeScatter(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 766d60edb86f..daaacff4121b 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -31,7 +31,7 @@ Type InferDistTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { return InferDistTypeBroadcast( call, ctx, [](const Call& call, const BlockBuilder& ctx, const TensorType& x1_ty, - const TensorType& x2_ty) { return DataType::Bool(); }); + const TensorType& x2_ty) { return DLDataType{kDLBool, 8, 1}; }); } /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index 5fd39b50f364..a6d3fd9ba124 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -41,8 +41,8 @@ Type 
InferDistTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_c TensorType x1_ty = input_dtensor_tys[0]->tensor_ty; TensorType x2_ty = input_dtensor_tys[1]->tensor_ty; - // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, x1_ty, x2_ty); + // Dtype + PrimType output_dtype(f_compute_out_dtype(call, ctx, x1_ty, x2_ty)); // ndims TVM_FFI_ICHECK(!x1_ty->IsUnknownNdim() && !x2_ty->IsUnknownNdim()) diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index b009630070cd..ff5bc986c0c7 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -154,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeRtoS(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); - DataType output_dtype = input_ty->dtype; + PrimType output_dtype = input_ty->dtype; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index 80fccbe115a9..b498f1a4a953 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -32,9 +32,9 @@ Type InferDistTypeMatmul(const Call& call, const BlockBuilder& ctx) { TensorType x2_ty = input_dtensor_tys[1]->tensor_ty; const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) + : attrs->out_dtype); if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index 1339a18e72d0..386401521974 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -33,7 +33,9 @@ Type InferDistTypeSoftmax(const Call& call, const BlockBuilder& ctx) { if (input_tensor_ty->IsUnknownNdim()) { TVM_FFI_VISIT_THROW(ValueError, call) << "Input of distributed operator must have known ndim"; } - if (!input_tensor_ty->IsUnknownDtype() && !input_tensor_ty->dtype.is_float()) { + PrimType input_dtype = input_tensor_ty->dtype; + // Softmax validation preserves the old float-kind check; lanes do not affect this policy. + if (!input_tensor_ty->IsUnknownDtype() && input_dtype.code() != DLDataTypeCode::kDLFloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " "dtype. 
However, the given input dtype is " << input_tensor_ty->dtype; diff --git a/src/relax/op/distributed/unary.cc b/src/relax/op/distributed/unary.cc index 4356b403c6d9..8e4ccce23a9c 100644 --- a/src/relax/op/distributed/unary.cc +++ b/src/relax/op/distributed/unary.cc @@ -25,7 +25,7 @@ namespace distributed { Type InferDistTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { return InferDistTypeUnary(call, ctx, - [](const TensorType& input_ty) { return DataType::Bool(); }); + [](const TensorType& input_ty) { return PrimType::Bool(); }); } RELAX_REGISTER_UNARY_ARITH_DIST_INFER_TYPE(abs, /*require_float_dtype=*/false); diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index 92c719ad0b98..be7ca27d3ade 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -40,15 +40,22 @@ Type InferDistTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compu distributed::DTensorType input_dtensor_ty = input_dtensor_tys[0]; TensorType input_tensor_ty = input_dtensor_ty->tensor_ty; + PrimType input_dtype = input_tensor_ty->dtype; + // Unary op validation preserves the old float-kind check; lanes do not affect this policy. if (require_float_dtype && !input_tensor_ty->IsUnknownDtype() && - !input_tensor_ty->dtype.is_float()) { + input_dtype.code() != DLDataTypeCode::kDLFloat) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. 
However, the given input dtype is " << input_tensor_ty->dtype; } auto output_ty = ffi::make_object(*input_tensor_ty.get()); - output_ty->dtype = f_compute_out_dtype(input_tensor_ty); + auto computed_dtype = f_compute_out_dtype(input_tensor_ty); + if constexpr (std::is_same_v, PrimType>) { + output_ty->dtype = computed_dtype; + } else { + output_ty->dtype = PrimType(computed_dtype); + } TensorType out_tensor_ty(output_ty); return distributed::DTensorType(out_tensor_ty, input_dtensor_ty->device_mesh, input_dtensor_ty->placement); diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index b92167e031f1..82b12c0fe26f 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -41,7 +41,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Resize3DAttrs::RegisterReflection(); } Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype) { + double extrapolation_value, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->roi = std::move(roi); attrs->layout = std::move(layout); @@ -51,7 +51,7 @@ Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->cubic_alpha = cubic_alpha; attrs->cubic_exclude = cubic_exclude; attrs->extrapolation_value = extrapolation_value; - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.image.resize2d"); return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); @@ -93,7 +93,9 @@ Type InferTypeResize2D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? 
data_ty->dtype : attrs->out_dtype; + PrimType out_dtype = attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? data_ty->dtype + : PrimType(attrs->out_dtype); ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -155,7 +157,7 @@ TVM_REGISTER_OP("relax.image.resize2d") Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype) { + double extrapolation_value, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->roi = std::move(roi); attrs->layout = std::move(layout); @@ -165,7 +167,7 @@ Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout attrs->cubic_alpha = cubic_alpha; attrs->cubic_exclude = cubic_exclude; attrs->extrapolation_value = extrapolation_value; - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.image.resize3d"); return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {}); @@ -207,7 +209,9 @@ Type InferTypeResize3D(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/"NCDHW", // /*tensor_name=*/"data"); - DataType out_dtype = attrs->out_dtype.is_void() ? data_ty->dtype : attrs->out_dtype; + PrimType out_dtype = attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? data_ty->dtype + : PrimType(attrs->out_dtype); ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -315,7 +319,7 @@ Type InferTypeGridSample(const Call& call, const BlockBuilder& ctx) { /*tgt_layout=*/is_ncdhw ? 
"NCDHW" : "NCHW", /*tensor_name=*/"data"); - DataType out_dtype = data_ty->dtype; + PrimType out_dtype = data_ty->dtype; ffi::Optional data_shape = CheckNdimPerLayoutAndGetShape(call, ctx, ffi::GetRef(data_ty), data_layout); @@ -422,7 +426,7 @@ Type InferTypeAffineGrid(const Call& call, const BlockBuilder& ctx) { } } - DataType out_dtype = data_ty->dtype; + PrimType out_dtype = data_ty->dtype; if (data_shape == nullptr || size_value == nullptr) { return TensorType(out_dtype, /*ndim=*/4, data_ty->vdevice); diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h index 382a3a162be2..1aaed69f9146 100644 --- a/src/relax/op/image/resize.h +++ b/src/relax/op/image/resize.h @@ -36,13 +36,13 @@ namespace relax { Expr resize2d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype); + double extrapolation_value, ffi::Optional out_dtype); /*! \brief Image resize3d operator. */ Expr resize3d(Expr data, Expr size, ffi::Array roi, ffi::String layout, ffi::String method, ffi::String coordinate_transformation_mode, ffi::String rounding_method, double cubic_alpha, int cubic_exclude, - double extrapolation_value, ffi::Optional out_dtype); + double extrapolation_value, ffi::Optional out_dtype); /*! \brief Image grid_sample operator. 
*/ Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout, diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 25ad9aa66d8e..f2c5b7da8614 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -87,7 +87,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { } }(); - auto view_dtype = [&]() -> std::optional { + auto view_dtype = [&]() -> std::optional { Type ty = GetType(arg_dtype); if (HasVoidType(arg_dtype)) { @@ -116,7 +116,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { } else if (ty.as()) { // The view changes the datatype, but we don't know what it is // being changed into. - return DataType::Void(); + return DLDataType{kDLOpaqueHandle, 0, 0}; } else { TVM_FFI_THROW(TypeError) << "Operator " << call->op << " expects the dtype argument to be a relax::DataTypeImm, " @@ -131,7 +131,7 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { // No byte offset is specified, so no change is applied. return IntImm::Int64(0); } else if (auto prim_ty = ty.as()) { - TVM_FFI_CHECK_EQ(prim_ty->dtype, DataType::Int(64), TypeError) + TVM_FFI_CHECK_EQ(prim_ty->dtype, (DLDataType{kDLInt, 64, 1}), TypeError) << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << ty; @@ -167,16 +167,15 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { output_ndim = data_ty->ndim; } - DataType output_dtype = view_dtype.value_or(data_ty->dtype); + DLDataType output_raw_dtype = view_dtype.value_or(data_ty->dtype->dtype); + PrimType output_dtype(output_raw_dtype); - // Helper function, returns the number of bytes per vectorized - // element. Cannot use `DataType::bytes`, as it returns the - // number of bytes per scalar element. 
- auto get_size_bytes = [](const DataType& dtype) -> ffi::Optional { - if (dtype.is_void()) { + // Helper function returns the number of bytes per vectorized element. + auto get_size_bytes = [](DLDataType dtype) -> ffi::Optional { + if ((((dtype).code == kDLOpaqueHandle) && ((dtype).bits == 0) && ((dtype).lanes == 0))) { return std::nullopt; } else { - auto size_bits = dtype.bits() * dtype.lanes(); + auto size_bits = ((dtype).bits) * static_cast((dtype).lanes); return IntImm::Int64((size_bits + 7) / 8); } }; @@ -199,8 +198,8 @@ Type InferTypeView(const Call& call, const BlockBuilder& ctx) { ffi::Optional input_nelements = get_num_elements(input_shape); ffi::Optional output_nelements = get_num_elements(output_shape); - ffi::Optional input_element_size = get_size_bytes(data_ty->dtype); - ffi::Optional output_element_size = get_size_bytes(output_dtype); + ffi::Optional input_element_size = get_size_bytes(data_ty->dtype->dtype); + ffi::Optional output_element_size = get_size_bytes(output_raw_dtype); if (input_nelements && output_nelements && input_element_size && output_element_size && view_relative_byte_offset) { @@ -329,8 +328,9 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { } if (HasVoidType(dtype)) { - auto data_dtype = data->ty.as().value()->dtype; - TVM_FFI_ICHECK(!data_dtype.is_void()) + DLDataType data_dtype = data->ty.as().value()->dtype->dtype; + TVM_FFI_ICHECK(!(((data_dtype).code == kDLOpaqueHandle) && ((data_dtype).bits == 0) && + ((data_dtype).lanes == 0))) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " << "or the input dtype is known. 
" diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index 83080537c1d0..62e7d2959346 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -143,7 +143,7 @@ Type InferTypeAttention(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(output_shape), q_ty->dtype, q_ty->vdevice); } -Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionAttention(const Call& call, DLDataType out_dtype) { return attention(call->args[0], call->args[1], call->args[2], std::nullopt, std::nullopt, std::nullopt, std::nullopt) .as_or_throw(); diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 1fa9b9b1ae94..90d58a9e662d 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -47,7 +47,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); TVM_FFI_ICHECK_GT(groups, 0) @@ -62,7 +62,8 @@ Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv1d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv1d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -91,9 +92,9 @@ Type InferTypeConv1d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -186,7 +187,7 @@ InferLayoutOutput InferLayoutConv1d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv1d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv1d(const Call& call, DLDataType out_dtype) { const auto* conv1d_attrs = call->attrs.as(); return conv1d(call->args[0], call->args[1], conv1d_attrs->strides, conv1d_attrs->padding, conv1d_attrs->dilation, conv1d_attrs->groups, conv1d_attrs->data_layout, @@ -210,7 +211,7 @@ TVM_REGISTER_OP("relax.nn.conv1d") Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -231,7 +232,8 @@ Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv2d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv2d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -260,9 +262,9 @@ Type InferTypeConv2d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, 
weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -336,9 +338,10 @@ InferLayoutOutput InferLayoutConv2d( SLayout desired_data_layout = (*it).second[0]; SLayout desired_weight_layout = (*it).second[1]; SLayout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - tirx::SLayout input_layout(attrs->data_layout, DataType::Int(64)); - tirx::SLayout kernel_layout(attrs->kernel_layout, DataType::Int(64)); - tirx::SLayout out_layout(attrs->out_layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout input_layout(attrs->data_layout, i64_ty); + tirx::SLayout kernel_layout(attrs->kernel_layout, i64_ty); + tirx::SLayout out_layout(attrs->out_layout, i64_ty); if ((desired_data_layout.ndim() == input_layout.ndim()) && (desired_weight_layout.ndim() == kernel_layout.ndim()) && @@ -396,7 +399,7 @@ InferLayoutOutput InferLayoutConv2d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv2d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv2d(const Call& call, DLDataType out_dtype) { const auto* conv2d_attrs = call->attrs.as(); return conv2d(call->args[0], call->args[1], conv2d_attrs->strides, conv2d_attrs->padding, conv2d_attrs->dilation, conv2d_attrs->groups, conv2d_attrs->data_layout, @@ -420,7 +423,7 @@ TVM_REGISTER_OP("relax.nn.conv2d") Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, 
ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); if (strides.size() == 1) { strides.push_back(strides[0]); @@ -443,7 +446,8 @@ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), std::move(padding), std::move(dilation), groups, data_layout, std::move(kernel_layout), out_layout.value_or(data_layout), - out_dtype.value_or(DataType::Void()), /*op_name=*/"relax.nn.conv3d"); + out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})), + /*op_name=*/"relax.nn.conv3d"); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -472,9 +476,9 @@ Type InferTypeConv3d(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -581,7 +585,7 @@ InferLayoutOutput InferLayoutConv3d( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv3d(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv3d(const Call& call, DLDataType out_dtype) { const auto* conv3d_attrs = call->attrs.as(); return conv3d(call->args[0], call->args[1], conv3d_attrs->strides, conv3d_attrs->padding, conv3d_attrs->dilation, conv3d_attrs->groups, conv3d_attrs->data_layout, @@ -604,7 +608,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); TVM_FFI_ICHECK_GT(groups, 0) @@ -630,7 +634,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = Op::Get("relax.nn.conv1d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -660,9 +664,9 @@ Type InferTypeConv1dTranspose(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? 
InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -758,7 +762,7 @@ InferLayoutOutput InferLayoutConv1dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv1dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv1d_transpose_attrs = call->attrs.as(); return conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides, conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding, @@ -786,7 +790,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding2D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -821,7 +825,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = Op::Get("relax.nn.conv2d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -852,9 +856,9 @@ Type InferTypeConv2dTranspose(const Call& call, const BlockBuilder& ctx) { 
ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -987,7 +991,7 @@ InferLayoutOutput InferLayoutConv2dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv2dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv2d_transpose_attrs = call->attrs.as(); return conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides, conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding, @@ -1015,7 +1019,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype) { + ffi::Optional out_dtype) { padding = GetCompletePadding3D(std::move(padding)); if (output_padding.size() == 1) { output_padding.push_back(output_padding[0]); @@ -1053,7 +1057,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, attrs->data_layout = data_layout; attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = out_layout.value_or(data_layout); - attrs->out_dtype = std::move(out_dtype.value_or(DataType::Void())); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); const Op& op = 
Op::Get("relax.nn.conv3d_transpose"); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -1084,9 +1088,9 @@ Type InferTypeConv3dTranspose(const Call& call, const BlockBuilder& ctx) { ffi::Optional weight_shape = CheckNdimPerLayoutAndGetShape(call, ctx, weight_ty, weight_layout); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? InferBinaryArithOpOutDtype(call, ctx, data_ty, weight_ty) + : attrs->out_dtype); ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, data_ty, weight_ty); if (!data_shape.defined() || !weight_shape.defined()) { return TensorType(out_dtype, out_layout.ndim(), vdevice); @@ -1227,7 +1231,7 @@ InferLayoutOutput InferLayoutConv3dTranspose( return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); } -Call InferMixedPrecisionConv3dTranspose(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionConv3dTranspose(const Call& call, DLDataType out_dtype) { const auto* conv3d_transpose_attrs = call->attrs.as(); return conv3d_transpose(call->args[0], call->args[1], conv3d_transpose_attrs->strides, conv3d_transpose_attrs->padding, conv3d_transpose_attrs->output_padding, diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h index b08eb8a83ff8..b33a19f07057 100644 --- a/src/relax/op/nn/convolution.h +++ b/src/relax/op/nn/convolution.h @@ -39,7 +39,7 @@ template inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::String out_layout, - DataType out_dtype, std::string op_name) { + DLDataType out_dtype, std::string op_name) { auto attrs = ffi::make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -48,7 +48,7 @@ inline Expr MakeConv(Expr data, Expr 
weight, ffi::Array strides, attrs->data_layout = std::move(data_layout); attrs->kernel_layout = std::move(kernel_layout); attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); + attrs->out_dtype = out_dtype; const Op& op = Op::Get(op_name); return Call(op, {data, weight}, Attrs(attrs), {}); } @@ -57,19 +57,19 @@ inline Expr MakeConv(Expr data, Expr weight, ffi::Array strides, Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! \brief 2D convolution */ Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! \brief 3D convolution */ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief One dimensional transposed convolution operator. @@ -81,7 +81,7 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief Two dimensional transposed convolution operator. @@ -93,7 +93,7 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); /*! * \brief Three dimensional transposed convolution operator. 
@@ -105,7 +105,7 @@ Expr conv3d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Array padding, ffi::Array output_padding, ffi::Array dilation, int groups, ffi::String data_layout, ffi::String kernel_layout, ffi::Optional out_layout, - ffi::Optional out_dtype); + ffi::Optional out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b24f81c72d49..c34f7afbc79d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -122,7 +122,9 @@ Type InferTypePRelu(const Call& call, const BlockBuilder& ctx) { if (data_ty->IsUnknownNdim()) { return data_ty; } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { + PrimType data_dtype = data_ty->dtype; + // PRelu preserves the old float-kind check; vector lanes are irrelevant to this check. + if (!data_ty->IsUnknownDtype() && data_dtype.code() != DLDataTypeCode::kDLFloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Prelu requires the input tensor to have float " "dtype. However, the given input dtype is " << data_ty->dtype; @@ -186,10 +188,14 @@ Type InferTypeSoftmax(const Call& call, const BlockBuilder& ctx) { if (data_ty->IsUnknownNdim()) { return data_ty; } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat()) { - TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " - "dtype. However, the given input dtype is " - << data_ty->dtype; + if (!data_ty->IsUnknownDtype()) { + PrimType data_dtype = data_ty->dtype; + // Softmax only requires a floating element kind; lane encoding is irrelevant to the check. + if (data_dtype.code() != kDLFloat && data_dtype.code() != kDLBfloat) { + TVM_FFI_VISIT_THROW(TypeError, call) << "Softmax requires the input tensor to have float " + "dtype. 
However, the given input dtype is " + << data_ty->dtype; + } } const auto* attrs = call->attrs.as(); NormalizeAxis(call, ctx, data_ty->ndim, attrs->axis); @@ -380,10 +386,14 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_ty->ndim, axes); } int n_axis = axes.size(); - if (!data_ty->IsUnknownDtype() && (!data_ty->dtype.is_float() && !data_ty->dtype.is_bfloat())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << op << " requires the input data to have float dtype. However, the given data dtype is " - << data_ty->dtype; + if (!data_ty->IsUnknownDtype()) { + PrimType data_dtype = data_ty->dtype; + // Norm ops only require a floating element kind; lane encoding is irrelevant to the check. + if (data_dtype.code() != kDLFloat && data_dtype.code() != kDLBfloat) { + TVM_FFI_VISIT_THROW(TypeError, call) + << op << " requires the input data to have float dtype. However, the given data dtype is " + << data_ty->dtype; + } } for (int i = 1; i < n_input; ++i) { if (input_ty[i]->dtype != data_ty->dtype) { @@ -462,7 +472,7 @@ Type InferTypeBatchNorm(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_ty, {attrs->axis}); - DataType dtype = input_ty[0]->dtype; + PrimType dtype = input_ty[0]->dtype; if (unknown_shape) { auto vdev = input_ty[0]->vdevice; return TupleType({TensorType(dtype, input_ty[0]->ndim, vdev), @@ -620,7 +630,9 @@ Type InferTypeGroupNorm(const Call& call, const BlockBuilder& ctx) { << channel_axis << ", axes: " << attrs->axes; } } - if (!data_ty->IsUnknownDtype() && !data_ty->dtype.is_float()) { + PrimType data_dtype = data_ty->dtype; + // GroupNorm preserves the old float-kind check; vector lanes are irrelevant to this check. 
+ if (!data_ty->IsUnknownDtype() && data_dtype.code() != DLDataTypeCode::kDLFloat) { TVM_FFI_VISIT_THROW(TypeError, call) << op << " expects that data must be float, but got " << data_ty->dtype; } @@ -890,7 +902,7 @@ Type InferTypeCrossEntropy(const Call& call, const BlockBuilder& ctx) { TensorType label_ty = input_ty[1]; // infer dtype - DataType dtype = InferBinaryArithOpOutDtype(call, ctx, pred_ty, label_ty); + PrimType dtype(InferBinaryArithOpOutDtype(call, ctx, pred_ty, label_ty)); // infer vdevice ffi::Optional vdevice = InferBinaryArithOpOutVDevice(call, ctx, pred_ty, label_ty); @@ -1002,23 +1014,26 @@ Type InferTypeNLLLoss(const Call& call, const BlockBuilder& ctx) { } // infer dtype, vdevice - DataType output_dtype; - ffi::Optional vdevice; - if (wgt_ty != nullptr) { - output_dtype = InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_ty), - ffi::GetRef(wgt_ty)); - vdevice = InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_ty), - ffi::GetRef(wgt_ty)); - } else { - output_dtype = pred_ty->dtype; - vdevice = pred_ty->vdevice; - } + PrimType output_dtype = + wgt_ty != nullptr + ? PrimType(InferBinaryArithOpOutDtype(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty))) + : pred_ty->dtype; + ffi::Optional vdevice = + wgt_ty != nullptr ? InferBinaryArithOpOutVDevice(call, ctx, ffi::GetRef(pred_ty), + ffi::GetRef(wgt_ty)) + : pred_ty->vdevice; // the type of targets must be int/uint. - if (!tgt_ty->IsUnknownDtype() && !tgt_ty->dtype.is_int() && !tgt_ty->dtype.is_uint()) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "NLLLoss expects the dtype of targets to be int/uint. However, the dtype of targets is " - << tgt_ty->dtype; + if (!tgt_ty->IsUnknownDtype()) { + PrimType target_dtype = tgt_ty->dtype; + // NLLLoss only needs the target element kind; vector lanes do not affect target indexing. 
+ if (!target_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !target_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) << "NLLLoss expects the dtype of targets to be " + "int/uint. However, the dtype of targets is " + << tgt_ty->dtype; + } } // infer ndim diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 856cd75c5902..84f994bc612f 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -275,7 +275,8 @@ InferLayoutOutput InferLayoutPool2d( ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout in_layout(attrs->layout, i64_ty); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetType(call->args[0]); TensorType data_ty = data_si.as().value(); @@ -675,7 +676,8 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ffi::ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tirx::SLayout in_layout(attrs->layout, DataType::Int(64)); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout in_layout(attrs->layout, i64_ty); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetType(call->args[0]); TensorType data_ty = data_si.as().value(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 9c58ab769950..16e5d5f20d0e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -409,9 +409,9 @@ static ffi::Optional InferCallTIROutputTypeFromArguments( TVM_FFI_ICHECK(packed_tuple_ty); PrimType dummy_arg_ty = [&]() { if (packed_tuple_ty->values) { - return PrimType(packed_tuple_ty->values.value()[i].dtype()); + return 
PrimType(packed_tuple_ty->values.value()[i].ty()); } else { - return PrimType(DataType::Int(64)); + return PrimType::Int(64); } }(); dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_ty)); @@ -1119,7 +1119,7 @@ Type InferTypeSize(const Call& call, const BlockBuilder& ctx) { auto* tensor_ty = GetType(call->args[0]).as(); TVM_FFI_ICHECK(tensor_ty) << "size expects a tensor input, but received " << arg_ty << "; use MatchCast if necessary"; - return TensorType(ShapeExpr(ffi::Array{}), DataType::Int(64)); + return TensorType(ShapeExpr(ffi::Array{}), PrimType::Int(64)); } TVM_REGISTER_OP("relax.size") @@ -1182,7 +1182,7 @@ Type ReturnShapeToTensorType(const Call& call, const BlockBuilder& ctx) { const auto* ty = GetTypeAs(call->args[0]); TVM_FFI_ICHECK(ty); int32_t ndim = ty->ndim; - return TensorType(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); + return TensorType(ShapeExpr({PrimExpr(ndim)}), PrimType::Int(64)); } TVM_REGISTER_OP("relax.shape_to_tensor") @@ -1209,10 +1209,10 @@ Type InferTypeAllocateTensor(const Call& call, const BlockBuilder& ctx) { << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); TVM_FFI_ICHECK(call->args[1].as()) << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[1].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; if (auto* prim_value_node = call->args[2].as()) { @@ -1284,10 +1284,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeMemAllocTensor(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK(GetTypeAs(call->args[2])) << "must be a Expr of ShapeType, but got " << call->args[1]->GetTypeKey(); - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = 
dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } if (call->args.size() == 5) { @@ -1408,10 +1408,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // vm alloc_tensor Type InferTypeVMAllocTensor(const Call& call, const BlockBuilder& ctx) { - DataType out_dtype; + PrimType out_dtype = PrimType::Void(); if (const auto* dtype_node = call->args[3].as()) { const DataTypeImm dtype_imm = ffi::GetRef(dtype_node); - out_dtype = dtype_imm->value; + out_dtype = PrimType(dtype_imm->value); } int64_t vdevice_index = -1; if (auto* prim_value_node = call->args[4].as()) { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index cb0d6034e2d1..a19f59d4d56a 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -184,14 +185,12 @@ std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { tvm::ffi::reflection::GlobalDef().def("relax.op." OpRegName, OpName); \ } -/************ Utilities ************/ - /*! * \brief Infer the type for unary elementwise ops. * \param call The context Call to the operator. * \param ctx The error reporting context. * \param f_compute_out_dtype The function to compute the output dtype, with - * signature DataType f_compute_out_dtype(const TensorType& input_ty). + * signature DLDataType or PrimType f_compute_out_dtype(const TensorType& input_ty). * \tparam require_float_dtype whether this op requires the input dtype to be float * \tparam Ftype the type of f_compute_out_dtype * \return The inferred type. 
@@ -199,15 +198,21 @@ std::tuple GetArgType(const Call& call, const BlockBuilder& ctx) { template inline Type InferTypeUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorType input_ty = GetUnaryInputTensorType(call, ctx); + DLDataType input_dtype = input_ty->dtype->dtype; if (require_float_dtype && !input_ty->IsUnknownDtype() && - (!input_ty->dtype.is_float() && !input_ty->dtype.is_bfloat())) { + (input_dtype.code != kDLFloat && input_dtype.code != kDLBfloat)) { TVM_FFI_VISIT_THROW(TypeError, call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " << input_ty->dtype; } auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = f_compute_out_dtype(input_ty); + auto computed_dtype = f_compute_out_dtype(input_ty); + if constexpr (std::is_same_v, PrimType>) { + output_ty->dtype = computed_dtype; + } else { + output_ty->dtype = PrimType(computed_dtype); + } if (call->ty_args.size() > 0) { auto defined_ty = call->ty_args[0].as(); TVM_FFI_ICHECK(defined_ty); @@ -274,9 +279,9 @@ InferLayoutOutput InferLayoutUnaryEwise( * \return The inferred element dtype. * \throw Throw exception if the Type doesn't have an element type. */ -inline std::optional GetElementDType(const Type& ty) { +inline std::optional GetElementDType(const Type& ty) { if (const auto* prim = ty.as()) { - return prim->dtype; + return ffi::GetRef(prim); } else if (const auto* tensor = ty.as()) { return tensor->dtype; } else { @@ -296,8 +301,8 @@ inline std::optional GetElementDType(const Type& ty) { * \return The inferred output dtype. 
* \throw Throw exception if the dtype of two input TensorType don’t match */ -inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, - const Type& lhs_ty, const Type& rhs_ty) { +inline DLDataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& ctx, + const Type& lhs_ty, const Type& rhs_ty) { auto opt_lhs_dtype = GetElementDType(lhs_ty); if (!opt_lhs_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -318,15 +323,17 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& } auto rhs_dtype = opt_rhs_dtype.value(); - if (lhs_dtype.is_void() || rhs_dtype.is_void()) { - return DataType::Void(); - } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { + if (lhs_dtype.IsVoid() || rhs_dtype.IsVoid()) { + return DLDataType{kDLOpaqueHandle, 0, 0}; + } else if (lhs_dtype->dtype != rhs_dtype->dtype && + !lhs_dtype.MatchesCode(DLDataTypeCode::kDLBool) && + !rhs_dtype.MatchesCode(DLDataTypeCode::kDLBool)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Binary operators must have the same datatype for both operands. " << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (Type of " << lhs_ty << "), and datatype " << rhs_dtype << " on the RHS (Type of " << rhs_ty << ")."; } - return lhs_dtype; + return lhs_dtype->dtype; } /*! 
@@ -469,7 +476,7 @@ bool IsIdentityPermutation(const std::vector& permutation); */ inline ffi::Array ConvertIntImmToInt64(const ffi::Array& int_imms) { return int_imms.Map( - [](const IntImm& i) { return cast(DataType::Int(64), i).as_or_throw(); }); + [](const IntImm& i) { return cast(PrimType::Int(64), i).as_or_throw(); }); } /************ Utilities for NN operators ************/ @@ -560,8 +567,9 @@ inline ffi::Array GetCompletePadding3D(ffi::Array padding) { inline std::pair CheckTensorLayout( const Call& call, const BlockBuilder& ctx, const ffi::String& tensor_layout, const ffi::String& tgt_layout, const ffi::String& tensor_name) { - tirx::SLayout _tensor_layout(tensor_layout, DataType::Int(64)); - tirx::SBijectiveLayout tensor2tgt(_tensor_layout, tirx::SLayout(tgt_layout, DataType::Int(64))); + tvm::PrimType i64_ty = tvm::PrimType::Int(64); + tirx::SLayout _tensor_layout(tensor_layout, i64_ty); + tirx::SBijectiveLayout tensor2tgt(_tensor_layout, tirx::SLayout(tgt_layout, i64_ty)); if (!tensor2tgt.defined()) { TVM_FFI_VISIT_THROW(ValueError, call) << call->op << " requires the given " << tensor_name << " layout to be convertible from " diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 84c411238473..cbc786de0f8e 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -51,11 +51,11 @@ Type InferTypeBroadcast(const Call& call, const BlockBuilder& ctx, FType f_compu << "Arguments to binary operators must be either R.Tensor or R.Prim types, " << "but expression " << call << " has RHS " << call->args[1] << ", which has Type " << rhs_ty; - // DateType - DataType output_dtype = f_compute_out_dtype(call, ctx, lhs_ty, rhs_ty); + // Dtype + PrimType output_dtype(f_compute_out_dtype(call, ctx, lhs_ty, rhs_ty)); if (lhs_ty.as() && rhs_ty.as()) { - return PrimType(output_dtype); + return output_dtype; } // VDevice @@ -136,7 +136,7 @@ Type InferTypeBroadcastArith(const Call& call, const BlockBuilder& ctx) { Type 
InferTypeBroadcastCMP(const Call& call, const BlockBuilder& ctx) { return InferTypeBroadcast(call, ctx, [](const Call& call, const BlockBuilder& ctx, const Type& lhs_ty, - const Type& rhs_ty) { return DataType::Bool(); }); + const Type& rhs_ty) { return DLDataType{kDLBool, 8, 1}; }); } InferLayoutOutput InferLayoutBinaryEwise( diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index e7a972896569..fbe3a0b0c534 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.full */ Expr full(ffi::Variant> shape, Expr fill_value, - ffi::Optional dtype) { + ffi::Optional dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = ffi::GetRef(expr); @@ -59,7 +59,7 @@ Expr full(ffi::Variant> shape, Expr fill_value, } ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.full"); return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); @@ -88,7 +88,8 @@ Type InferTypeFull(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? fill_value_ty->dtype : attrs->dtype; + PrimType out_dtype = attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? 
fill_value_ty->dtype + : PrimType(attrs->dtype); return TensorType(/*shape=*/call->args[0], out_dtype, fill_value_ty->vdevice); } @@ -104,9 +105,9 @@ TVM_REGISTER_OP("relax.full") .set_attr("FPurity", true); /* relax.full_like */ -Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.full_like"); return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); } @@ -127,11 +128,11 @@ Type InferTypeFullLike(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - if (attrs->dtype.is_void()) { + if (attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0}) { return data_ty; } else { auto output_ty = ffi::make_object(*data_ty.get()); - output_ty->dtype = attrs->dtype; + output_ty->dtype = PrimType(attrs->dtype); return TensorType(output_ty); } } @@ -158,25 +159,26 @@ Type InferTypeOnesZeros(const Call& call, const BlockBuilder& ctx) { << call->args[0]->ty->GetTypeKey(); } const auto* attrs = call->attrs.as(); - return TensorType(/*shape=*/call->args[0], attrs->dtype); + return TensorType(/*shape=*/call->args[0], PrimType(attrs->dtype)); } // Structure info inference for ones_like and zeros_like Type InferTypeOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - if (attrs->dtype.is_void()) { + if (attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0}) { return data_ty; } else { auto output_ty = ffi::make_object(*data_ty.get()); - output_ty->dtype = attrs->dtype; + output_ty->dtype = PrimType(attrs->dtype); return TensorType(output_ty); } } /* relax.ones & relax.ones_like */ -Expr ones(Expr shape, DataType dtype) { - TVM_FFI_ICHECK(!dtype.is_void()) << "Ones op expects 
the input dtype not to be void"; +Expr ones(Expr shape, DLDataType dtype) { + TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0})) + << "Ones op expects the input dtype not to be void"; ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -184,9 +186,9 @@ Expr ones(Expr shape, DataType dtype) { return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr ones_like(Expr x, ffi::Optional dtype) { +Expr ones_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.ones_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } @@ -212,8 +214,9 @@ TVM_REGISTER_OP("relax.ones_like") .set_attr("FPurity", true); /* relax.zeros & relax.zeros_like */ -Expr zeros(Expr shape, DataType dtype) { - TVM_FFI_ICHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; +Expr zeros(Expr shape, DLDataType dtype) { + TVM_FFI_ICHECK((dtype != DLDataType{kDLOpaqueHandle, 0, 0})) + << "Zeros op expects the input dtype not to be void"; ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -221,9 +224,9 @@ Expr zeros(Expr shape, DataType dtype) { return Call(op, {std::move(shape)}, Attrs(attrs), {}); } -Expr zeros_like(Expr x, ffi::Optional dtype) { +Expr zeros_like(Expr x, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.zeros_like"); return Call(op, {std::move(x)}, Attrs(attrs), {}); } @@ -249,16 +252,16 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attr("FPurity", true); /* relax.eye & relax.eye_like */ -Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { +Expr eye(PrimValue n, PrimValue m, PrimValue k, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); 
attrs->dtype = dtype; static const Op& op = Op::Get("relax.eye"); return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); } -Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->dtype = dtype.value_or(DataType::Void()); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.eye_like"); return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); } @@ -285,8 +288,8 @@ Type InferTypeEye(const Call& call, const BlockBuilder& ctx) { PrimExpr n = get_prim_value(call->args[0], "n"); PrimExpr m = get_prim_value(call->args[1], "m"); - DataType dtype = call->attrs.as()->dtype; - return TensorType(ShapeExpr({n, m}), dtype); + DLDataType dtype = call->attrs.as()->dtype; + return TensorType(ShapeExpr({n, m}), PrimType(dtype)); } Type InferTypeEyeLike(const Call& call, const BlockBuilder& ctx) { @@ -309,7 +312,8 @@ Type InferTypeEyeLike(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->dtype.is_void() ? x_ty->dtype : attrs->dtype; + PrimType out_dtype = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? 
x_ty->dtype : PrimType(attrs->dtype); return TensorType(x_ty->shape.value(), out_dtype, x_ty->vdevice); } @@ -333,7 +337,7 @@ TVM_REGISTER_OP("relax.eye_like") .set_attr("FPurity", true); /* relax.arange */ -Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { +Expr arange(PrimValue start, PrimValue stop, PrimValue step, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.arange"); @@ -362,17 +366,18 @@ Type InferTypeArange(const Call& call, const BlockBuilder& ctx) { PrimExpr start = get_prim_value(call->args[0], "start"); PrimExpr end = get_prim_value(call->args[1], "end"); PrimExpr step = get_prim_value(call->args[2], "step"); - DataType dtype = call->attrs.as()->dtype; + DLDataType dtype = call->attrs.as()->dtype; PrimExpr num_elem; - if (start.dtype().is_int() && end.dtype().is_int() && step.dtype().is_int()) { + if (start.ty().code() == DLDataTypeCode::kDLInt && end.ty().code() == DLDataTypeCode::kDLInt && + step.ty().code() == DLDataTypeCode::kDLInt) { num_elem = tvm::floordiv((end - start + step - 1), step); } else { - num_elem = tvm::cast(tvm::DataType::Int(64), - tvm::ceil(tvm::cast(tvm::DataType::Float(32), end - start) / step)); + num_elem = tvm::cast(tvm::PrimType::Int(64), + tvm::ceil(tvm::cast(tvm::PrimType::Float(32), end - start) / step)); } arith::Analyzer analyzer; num_elem = analyzer->Simplify(num_elem); - return TensorType(ShapeExpr({num_elem}), dtype); + return TensorType(ShapeExpr({num_elem}), PrimType(dtype)); } TVM_REGISTER_OP("relax.arange") @@ -387,7 +392,7 @@ TVM_REGISTER_OP("relax.arange") /* relax.hamming_window */ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, - DataType dtype) { + DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("relax.hamming_window"); @@ -401,8 +406,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } Type 
InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { - DataType dtype = call->attrs.as()->dtype; - if (dtype.is_int() || dtype.is_uint() || dtype.is_uint()) { + DLDataType dtype = call->attrs.as()->dtype; + if (dtype.code == DLDataTypeCode::kDLInt || dtype.code == DLDataTypeCode::kDLUInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "Hamming Window expects the datatype to be float but got " << dtype; } @@ -422,7 +427,7 @@ Type InferTypeHammingWindow(const Call& call, const BlockBuilder& ctx) { << window_size; } window_size = analyzer->Simplify(window_size); - return TensorType(ShapeExpr({window_size}), dtype); + return TensorType(ShapeExpr({window_size}), PrimType(dtype)); } TVM_REGISTER_OP("relax.hamming_window") diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 284448111739..497a535a4d0f 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -42,7 +42,7 @@ namespace relax { * \return The result tensor. */ Expr full(ffi::Variant> shape, Expr fill_value, - ffi::Optional dtype); + ffi::Optional dtype); /*! * \brief Construct a tensor such that @@ -55,7 +55,7 @@ Expr full(ffi::Variant> shape, Expr fill_value, * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); +Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); /*! * \brief Construct a tensor of all ones, with the input shape and dtype. @@ -63,7 +63,7 @@ Expr full_like(Expr x, Expr fill_value, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr ones(Expr shape, DataType dtype); +Expr ones(Expr shape, DLDataType dtype); /*! * \brief Construct a tensor with all ones, with shape of the input tensor shape. @@ -73,7 +73,7 @@ Expr ones(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. 
*/ -Expr ones_like(Expr x, ffi::Optional dtype); +Expr ones_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a tensor of all zeros, with the input shape and dtype. @@ -81,7 +81,7 @@ Expr ones_like(Expr x, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr zeros(Expr shape, DataType dtype); +Expr zeros(Expr shape, DLDataType dtype); /*! * \brief Construct a tensor with all zeros, with shape of the input tensor shape. @@ -91,7 +91,7 @@ Expr zeros(Expr shape, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr zeros_like(Expr x, ffi::Optional dtype); +Expr zeros_like(Expr x, ffi::Optional dtype); /*! * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. @@ -102,7 +102,7 @@ Expr zeros_like(Expr x, ffi::Optional dtype); * \param dtype The data type of the created tensor. * \return The result tensor. */ -Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); +Expr eye(PrimValue n, PrimValue m, PrimValue k, DLDataType dtype); /*! * \brief Construct a tensor with ones on the diagonal and zeros elsewhere, @@ -115,10 +115,10 @@ Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); * void, the input tensor's dtype will be used. * \return The result tensor. */ -Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); +Expr eye_like(Expr x, PrimValue k, ffi::Optional dtype); /*! \brief Construct a tensor with evenly spaced elements. */ -Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); +Expr arange(PrimValue start, PrimValue stop, PrimValue step, DLDataType dtype); /*! * \brief Hamming window function. @@ -131,7 +131,7 @@ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); * \return The result tensor. */ Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, - DataType dtype); + DLDataType dtype); /*! 
\brief Return the lower triangular part of a matrix or a batch of matrices. */ Expr tril(Expr x, Expr k); diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc index 907dffb0b3f3..ec1043a025e1 100644 --- a/src/relax/op/tensor/datatype.cc +++ b/src/relax/op/tensor/datatype.cc @@ -38,7 +38,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.astype */ -Expr astype(Expr x, DataType dtype) { +Expr astype(Expr x, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -55,7 +55,7 @@ Type InferTypeAstype(const Call& call, const BlockBuilder& ctx) { TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); - new_ty->dtype = attrs->dtype; + new_ty->dtype = PrimType(attrs->dtype); return TensorType(new_ty); } @@ -70,7 +70,7 @@ TVM_REGISTER_OP("relax.astype") /* relax.wrap_param */ -Expr MakeWrapParam(Expr data, DataType dtype) { +Expr MakeWrapParam(Expr data, DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -87,7 +87,7 @@ Type InferTypeWrapParam(const Call& call, const BlockBuilder& ctx) { TensorType ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); ffi::ObjectPtr new_ty = ffi::make_object(*ty.get()); - new_ty->dtype = attrs->dtype; + new_ty->dtype = PrimType(attrs->dtype); return TensorType(new_ty); } diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h index b612c45fc941..db2ee396c0d6 100644 --- a/src/relax/op/tensor/datatype.h +++ b/src/relax/op/tensor/datatype.h @@ -37,7 +37,7 @@ namespace relax { * \param dtype The target data type * \return The casted result. */ -Expr astype(Expr x, DataType dtype); +Expr astype(Expr x, DLDataType dtype); /*! * \brief A wrapper to wrap the input const tensor to the given data type. @@ -45,7 +45,7 @@ Expr astype(Expr x, DataType dtype); * \param dtype The target data type * \return The wrapped result. 
*/ -Expr wrap_param(Expr x, DataType dtype); +Expr wrap_param(Expr x, DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 515f37126183..e42feb0ae06c 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -72,7 +72,7 @@ Type InferTypeTake(const Call& call, const BlockBuilder& ctx) { if (auto tensor_ty = ty.as()) { return tensor_ty.value(); } else if (auto prim_ty = ty.as()) { - return TensorType(ShapeExpr(ffi::Array{}), prim_ty->dtype); + return TensorType(ShapeExpr(ffi::Array{}), ffi::GetRef(prim_ty)); } else { TVM_FFI_VISIT_THROW(TypeError, call) << "Operator " << call->op << " requires the indices argument to be " @@ -84,11 +84,15 @@ Type InferTypeTake(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "Take op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "Take op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_ty->dtype; + } } const auto* attrs = call->attrs.as(); @@ -309,7 +313,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { } }(); - TVM_FFI_ICHECK(IsBaseOf(relax::TensorType(DataType::Void(), kUnknownNDim), GetType(data))) + TVM_FFI_ICHECK(IsBaseOf(relax::TensorType(PrimType::Void(), kUnknownNDim), GetType(data))) << "Operator " << call->op << " requires the first argument to be a tensor. 
" << "However, in expression " << call << ", the first argument " << data << " has type " << GetType(data); @@ -325,9 +329,8 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* tuple = ty.as(); if (!tuple) return false; - return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const Type& field) { - return IsBaseOf(tvm::PrimType(DataType::Int(64)), field); - }); + return std::all_of(tuple->fields.begin(), tuple->fields.end(), + [](const Type& field) { return IsBaseOf(tvm::PrimType::Int(64), field); }); }; auto check_tuple = [&](const char* name, Expr expr) { auto ty = GetType(expr); @@ -347,7 +350,7 @@ Type InferTypeStridedSlice(const Call& call, const BlockBuilder& ctx) { const auto* data_ty = data->ty.as(); - DataType dtype = DataType::Void(); + PrimType dtype(DLDataType{kDLOpaqueHandle, 0, 0}); ffi::Optional vdevice = std::nullopt; int ndim = kUnknownNDim; if (data_ty) { @@ -545,7 +548,7 @@ Type InferTypeDynStridedSlice(const Call& call, const BlockBuilder& ctx) { LOG(WARNING) << "Dynamic strided slice assumes " << name << " to be int64 when it is not specified."; } else { - TVM_FFI_ICHECK(ty->dtype == DataType::Int(64)) + TVM_FFI_ICHECK(ty->dtype == PrimType::Int(64)) << "Dynamic strided_slice expects the input " << name << "values to be all int64. 
However, " << name << " has dtype " << ty->dtype << "."; } diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index bf57670e7f2a..97955eb62455 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -88,24 +88,21 @@ std::tuple> GetTensorArgInfoWithIndex(const C return {ffi::GetRef(tensor_ty), int_imm_axis}; } -DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } +tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, PrimType field_ty) { + tirx::Var dlpack_handle("dlpack_handle", PrimType::Handle()); -tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataType field_dtype) { - tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); - - tirx::Var value("value", field_dtype); + tirx::Var value("value", field_ty); tirx::Stmt body = tirx::SeqStmt( - {tirx::Bind(value, tirx::Call(field_dtype, tirx::builtin::tvm_struct_get(), + {tirx::Bind(value, tirx::Call(field_ty, tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(field)})), tirx::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, tvm::PrimType(field_dtype), {}, - attrs); + tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, field_ty, {}, attrs); - FuncType ty({TensorType(DataType::Void(), kUnknownNDim)}, PrimType(field_dtype)); + FuncType ty({TensorType(PrimType::Void(), kUnknownNDim)}, field_ty); func->ty = ty; return func; @@ -120,23 +117,14 @@ Expr tensor_dtype_code(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(8); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeCode(const Call& call, const BlockBuilder&) { return 
PrimType::UInt(8); } Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code"); return Call(gvar_getter, {arg}); @@ -158,23 +146,14 @@ Expr tensor_dtype_bits(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(8); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeBits(const Call& call, const BlockBuilder&) { return PrimType::UInt(8); } Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits"); return Call(gvar_getter, {arg}); @@ -196,23 +175,14 @@ Expr tensor_dtype_lanes(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(16); - - DataType dtype = GetTensorDataType(call); - if (dtype.is_void()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorDtypeLanes(const Call& call, const BlockBuilder&) { return PrimType::UInt(16); } Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { 
- auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes"); return Call(gvar_getter, {arg}); @@ -234,23 +204,14 @@ Expr tensor_ndim(Expr expr) { return Call(op, {expr}); } -Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(32); - - auto ty = GetTensorArgInfo(call); - if (ty->IsUnknownNdim()) { - return PrimType(dlpack_type); - } else { - return PrimType(dlpack_type); - } -} +Type InferTypeTensorNDim(const Call& call, const BlockBuilder&) { return PrimType::Int(32); } Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { - auto field_dtype = call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); Expr arg = call->args[0]; tirx::PrimFunc getter = - GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_dtype); + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_ty); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim"); return Call(gvar_getter, {arg}); @@ -273,45 +234,45 @@ Expr tensor_shape_i(Expr expr) { } Type InferTypeTensorShape(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(64); + auto dlpack_type = PrimType::Int(64); auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call); auto tensor_shape = tensor_ty->GetShape(); if (int_imm_axis && tensor_shape.defined()) { - return PrimType(tensor_shape.value()[int_imm_axis.value()].dtype()); + return tensor_shape.value()[int_imm_axis.value()].ty(); } else { - return PrimType(dlpack_type); + return dlpack_type; } } Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { - auto field_dtype = 
call->ty.as_or_throw()->dtype; + PrimType field_ty = call->ty.as_or_throw(); tirx::PrimFunc getter = [&]() -> tirx::PrimFunc { - tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); - tirx::Var axis("axis", DataType::Int(64)); + tirx::Var dlpack_handle("dlpack_handle", PrimType::Handle()); + tirx::Var axis("axis", PrimType::Int(64)); - tirx::Var ndim("ndim", DataType::Int(32)); + tirx::Var ndim("ndim", PrimType::Int(32)); - tirx::Buffer shape_buffer = tirx::decl_buffer({ndim}, field_dtype, "shape"); + tirx::Buffer shape_buffer = tirx::decl_buffer({ndim}, field_ty, "shape"); - tirx::Var extent("extent", field_dtype); + tirx::Var extent("extent", field_ty); tirx::Stmt body = tirx::SeqStmt( {tirx::AssertStmt(0 <= axis, tirx::StringImm("RuntimeError"), {tirx::StringImm("Specified axis may not be negative")}), tirx::Bind(ndim, - tirx::Call(ndim->dtype, tirx::builtin::tvm_struct_get(), + tirx::Call(ndim.ty(), tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), tirx::AssertStmt( - axis < tvm::cast(axis->dtype, ndim), tirx::StringImm("RuntimeError"), + axis < tvm::cast(axis.ty(), ndim), tirx::StringImm("RuntimeError"), {tirx::StringImm( "Specified axis may not be larger than the tensor's dimensionality")}), tirx::Bind(shape_buffer->data, - tirx::Call(DataType::Handle(), tirx::builtin::tvm_struct_get(), + tirx::Call(tvm::PrimType::Handle(), tirx::builtin::tvm_struct_get(), {dlpack_handle, IntImm::Int32(0), IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), tirx::DeclBuffer(shape_buffer), tirx::Bind(extent, tirx::BufferLoad(shape_buffer, {axis})), @@ -319,10 +280,9 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); - tirx::PrimFunc func({dlpack_handle, axis}, body, tvm::PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func({dlpack_handle, axis}, body, field_ty, {}, 
attrs); - FuncType ty({TensorType(DataType::Void(), kUnknownNDim), PrimType(axis->dtype)}, - PrimType(field_dtype)); + FuncType ty({TensorType(PrimType::Void(), kUnknownNDim), axis.ty()}, field_ty); func->ty = ty; return func; }(); @@ -349,7 +309,7 @@ Expr tensor_stride_i(Expr expr) { } Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::Int(64); + auto dlpack_type = PrimType::Int(64); auto [tensor_ty, int_imm_axis] = GetTensorArgInfoWithIndex(call); @@ -373,9 +333,9 @@ Type InferTypeTensorStride(const Call& call, const BlockBuilder&) { for (size_t axis = int_imm_axis.value() + 1; axis < tensor_shape.size(); axis++) { stride = stride * tensor_shape[axis]; } - return PrimType(stride.dtype()); + return stride.ty(); } else { - return PrimType(dlpack_type); + return dlpack_type; } } @@ -396,7 +356,7 @@ Expr tensor_byte_offset(Expr expr) { } Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(64); + auto dlpack_type = PrimType::UInt(64); auto tensor_ty = GetTensorArgInfo(call); @@ -405,9 +365,9 @@ Type InferTypeTensorByteOffset(const Call& call, const BlockBuilder&) { // Relax implicitly requires that the byte offset is zero for any // legalizable tensor. See InferTypeTensorStride for full // explanation. - return PrimType(dlpack_type); + return dlpack_type; } else { - return PrimType(dlpack_type); + return dlpack_type; } } @@ -427,7 +387,7 @@ Expr tensor_elem_offset(Expr expr) { } Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { - auto dlpack_type = DataType::UInt(64); + auto dlpack_type = PrimType::UInt(64); auto tensor_ty = GetTensorArgInfo(call); @@ -436,9 +396,9 @@ Type InferTypeTensorElemOffset(const Call& call, const BlockBuilder&) { // Relax implicitly requires that the element offset is zero for // any legalizable tensor. See InferTypeTensorStride for // full explanation. 
- return PrimType(dlpack_type); + return dlpack_type; } else { - return PrimType(dlpack_type); + return dlpack_type; } } diff --git a/src/relax/op/tensor/inspect.h b/src/relax/op/tensor/inspect.h index 3f820ab58a83..92cc4c256c79 100644 --- a/src/relax/op/tensor/inspect.h +++ b/src/relax/op/tensor/inspect.h @@ -36,7 +36,7 @@ namespace inspect { * `TensorType`. * * \returns The uint8_t value of the type_code, with - * `PrimType(DataType::UInt(8))` + * `PrimType::UInt(8)` */ Expr tensor_dtype_code(Expr expr); @@ -46,7 +46,7 @@ Expr tensor_dtype_code(Expr expr); * `TensorType`. * * \returns The uint8_t value of the number of bits, with - * `PrimType(DataType::UInt(8))`. For vectorized types, returns + * `PrimType::UInt(8)`. For vectorized types, returns * the bit width of the underlying scalar type (e.g. 32 for * "float32x4", not 128). */ @@ -58,7 +58,7 @@ Expr tensor_dtype_bits(Expr expr); * `TensorType`. * * \returns The uint16_t value of the number of lanes, with - * `PrimType(DataType::UInt(16))` + * `PrimType::UInt(16)` */ Expr tensor_dtype_lanes(Expr expr); @@ -68,7 +68,7 @@ Expr tensor_dtype_lanes(Expr expr); * `TensorType`. * * \returns The int32_t value of the dimensionality, with - * `PrimType(DataType::Int(32))`. + * `PrimType::Int(32)`. */ Expr tensor_ndim(Expr expr); @@ -81,7 +81,7 @@ Expr tensor_ndim(Expr expr); * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimType(DataType::Int(64))`. + * `PrimType::Int(64)`. */ Expr tensor_shape_i(Expr expr, Expr axis); @@ -98,7 +98,7 @@ Expr tensor_shape_i(Expr expr, Expr axis); * axis < tensor_ndim(expr)`, or else the results are undefined. * * \returns The int64_t extent of the specified tensor axis, with - * `PrimType(DataType::Int(64))`. + * `PrimType::Int(64)`. */ Expr tensor_stride_i(Expr expr, Expr axis); @@ -107,7 +107,7 @@ Expr tensor_stride_i(Expr expr, Expr axis); * \param expr The relax expression to be inspected. 
Must have * `TensorType`. * - * \returns The uint64_t byte offset, with `PrimType(DataType::UInt(64))`. + * \returns The uint64_t byte offset, with `PrimType::UInt(64)`. */ Expr tensor_byte_offset(Expr expr); @@ -120,7 +120,7 @@ Expr tensor_byte_offset(Expr expr); * \param expr The relax expression to be inspected. Must have * `TensorType`. * - * \returns The uint64_t element offset, with `PrimType(DataType::UInt(64))`. + * \returns The uint64_t element offset, with `PrimType::UInt(64)`. */ Expr tensor_elem_offset(Expr expr); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index a1693c6563f2..6ea68b422378 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -42,9 +42,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { /* relax.matmul */ -Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); - attrs->out_dtype = out_dtype.value_or(DataType::Void()); + attrs->out_dtype = out_dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); static const Op& op = Op::Get("relax.matmul"); return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); @@ -74,9 +74,9 @@ Type InferTypeMatmul(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - DataType out_dtype = attrs->out_dtype.is_void() - ? InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) - : attrs->out_dtype; + PrimType out_dtype = PrimType(attrs->out_dtype == DLDataType{kDLOpaqueHandle, 0, 0} + ? 
InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty) + : attrs->out_dtype); if (x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { if (vdev.defined()) { @@ -158,7 +158,7 @@ Type InferTypeMatmul(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(output_shape), out_dtype); } -Call InferMixedPrecisionMatmul(const Call& call, const DataType& out_dtype) { +Call InferMixedPrecisionMatmul(const Call& call, DLDataType out_dtype) { return matmul(call->args[0], call->args[1], out_dtype).as_or_throw(); } @@ -218,17 +218,17 @@ Type InferTypeEinsum(const Call& call, const BlockBuilder& ctx) { ffi::String subscripts = attrs->subscripts; - DataType operand_dtype = operands_tensor_ty[0]->dtype; + PrimType operand_ty = operands_tensor_ty[0]->dtype; std::vector> input_shapes; input_shapes.reserve(operands_tensor_ty.size()); for (TensorType tensor_ty : operands_tensor_ty) { // Check the input tuple consists of tensors with same dtype - if (tensor_ty->dtype != operand_dtype) { + if (tensor_ty->dtype != operand_ty) { TVM_FFI_VISIT_THROW(TypeError, call) << "Einsum expects all input tensors to have the same dtype. 
However, the " "input contains tensors with dtype " - << operand_dtype << " and " << tensor_ty->dtype; + << operand_ty << " and " << tensor_ty->dtype; } // Get input shapes @@ -237,18 +237,18 @@ Type InferTypeEinsum(const Call& call, const BlockBuilder& ctx) { input_shapes.push_back(shape_expr->values); } else { if (!vdevice_unknown) { - return TensorType(operand_dtype, tensor_ty->ndim, vdev); + return TensorType(operand_ty, tensor_ty->ndim, vdev); } - return TensorType(operand_dtype, tensor_ty->ndim); + return TensorType(operand_ty, tensor_ty->ndim); } } // Calculate output shape using InferEinsumShape in topi ffi::Array oshape = topi::InferEinsumShape(subscripts, input_shapes); if (!vdevice_unknown) { - return TensorType(ShapeExpr(oshape), operand_dtype, vdev); + return TensorType(ShapeExpr(oshape), operand_ty, vdev); } - return TensorType(ShapeExpr(oshape), operand_dtype); + return TensorType(ShapeExpr(oshape), operand_ty); } TVM_REGISTER_OP("relax.einsum") diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h index ddfceae4dc35..481193f897b8 100644 --- a/src/relax/op/tensor/linear_algebra.h +++ b/src/relax/op/tensor/linear_algebra.h @@ -41,7 +41,7 @@ namespace relax { * When it is not specified, the output dtype will be the same as input dtype. * \return The computed result. */ -Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); +Expr matmul(Expr x1, Expr x2, ffi::Optional out_dtype); /*! * \brief Einstein summation on the operands. 
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index caa730091383..8fe14c78555f 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -35,7 +35,7 @@ #include #include -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace relax { @@ -219,7 +219,7 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int output_ndim = attrs->axis.has_value() ? kUnknownNDim : 1; - DataType output_dtype = DataType::Void(); + PrimType output_dtype = PrimType::Void(); ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; @@ -229,9 +229,9 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { for (TensorType ty : tensor_ty) { // Update the output dtype. - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { is_void_dtype = true; - } else if (output_dtype.is_void()) { + } else if (output_dtype.IsVoid()) { output_dtype = ty->dtype; } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -285,7 +285,7 @@ Type InferTypeConcat(const Call& call, const BlockBuilder& ctx) { } if (is_void_dtype) { - output_dtype = DataType::Void(); + output_dtype = PrimType::Void(); } if (vdevice_unknown) { vdev = std::nullopt; @@ -573,14 +573,16 @@ Type InferTypeIndexTensor(const Call& call, const BlockBuilder& ctx) { << "index_tensor expects a non‑empty tuple of index tensors"; } - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; int n_indices = static_cast(indices_ty.size()); ffi::Optional vdev = data_ty->vdevice; // Indices must be integers for (int i = 0; i < n_indices; ++i) { const auto& s = indices_ty[i]; - if (!s->IsUnknownDtype() && !s->dtype.is_int()) { + PrimType index_dtype = s->dtype; + // Indexing only requires integer element kind; vector lanes do not affect shape inference. 
+ if (!s->IsUnknownDtype() && index_dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "index_tensor requires every index tensor to have an integer dtype; " << "index " << i << " has dtype " << s->dtype; @@ -725,9 +727,10 @@ Type InferTypeLayoutTransform(const Call& call, const BlockBuilder& ctx) { // Check pad_value has same dtype as input. if (optional_pad_value.defined()) { PrimExpr padded_value = optional_pad_value.value()->value; - if (padded_value->dtype != data_ty->dtype) { + PrimType padded_dtype = padded_value.ty(); + if (padded_dtype != data_ty->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) - << "layout_transform pad_value dtype (" << padded_value->dtype << ") and input dtype (" + << "layout_transform pad_value dtype (" << padded_dtype << ") and input dtype (" << data_ty->dtype << ") must be the same"; } } @@ -916,9 +919,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, "Array of PrimExprs. However, the given new shape is " << shape; PrimExpr len = ffi::GetRef(_len); - TVM_FFI_ICHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " - "integers. However, the give new shape is " - << shape; + TVM_FFI_ICHECK(len.ty().code() == DLDataTypeCode::kDLInt) + << "Reshape requires the new shape values to be all " + "integers. However, the give new shape is " + << shape; const auto* int_len = len.as(); if (int_len != nullptr && int_len->value == 0) { // Note that this dimension should be copied from the original shape. 
@@ -1108,7 +1112,7 @@ Type InferTypeSplit(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_NE(axis, -1); - IntImm zero(DataType::Int(64), /*value=*/0); + IntImm zero(tvm::PrimType::Int(64), /*value=*/0); std::vector output_ty; for (size_t i = 0; i < p_indices.size() + 1; i++) { @@ -1489,7 +1493,7 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { // Default axis is 0 if not specified int output_ndim = tensor_ty[0]->ndim + 1; // Stack adds one dimension - DataType output_dtype = DataType::Void(); + PrimType output_dtype = PrimType::Void(); ffi::Optional vdev = std::nullopt; bool shape_unknown = false; bool is_void_dtype = false; @@ -1499,9 +1503,9 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { for (TensorType ty : tensor_ty) { // Check dtype consistency - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { is_void_dtype = true; - } else if (output_dtype.is_void()) { + } else if (output_dtype.IsVoid()) { output_dtype = ty->dtype; } else if (ty->dtype != output_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -1542,7 +1546,7 @@ Type InferTypeStack(const Call& call, const BlockBuilder& ctx) { } } - if (is_void_dtype) output_dtype = DataType::Void(); + if (is_void_dtype) output_dtype = PrimType::Void(); if (vdevice_unknown) vdev = std::nullopt; // Normalize axis (default to 0 if not specified) @@ -1650,7 +1654,7 @@ Type InferTypeCollapseSumLike(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = input_ty[0]; TensorType collapse_target_ty = input_ty[1]; - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; if (data_ty->shape.defined()) { @@ -1711,7 +1715,7 @@ Type InferTypeCollapseSumTo(const Call& call, const BlockBuilder& ctx) { << call->args[1]->ty->GetTypeKey(); } - DataType output_dtype = data_ty->dtype; + PrimType output_dtype = data_ty->dtype; ffi::Optional> data_shape_value; if (data_ty->shape.defined()) { @@ -2192,7 +2196,9 @@ Type 
InferTypeGatherElements(const Call& call, const BlockBuilder& ctx) { << call->args[1]->ty->GetTypeKey(); } - if (!indices_ty->IsUnknownDtype() && !indices_ty->dtype.is_int()) { + PrimType indices_dtype = indices_ty->dtype; + // Gather indices only require integer element kind; vector lanes do not affect shape inference. + if (!indices_ty->IsUnknownDtype() && indices_dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherElements requires the input indices to have int64 dtype. However, the " << "given indices dtype is " << indices_ty->dtype; @@ -2295,7 +2301,7 @@ Type InferTypeGatherND(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_GE(attrs->batch_dims, 0); int batch_dims = static_cast(attrs->batch_dims); int input_dims = data_ty->ndim; - if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(64)) { + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != PrimType::Int(64)) { TVM_FFI_VISIT_THROW(TypeError, call) << "GatherND requires the input indices to have int64 dtype. However, the " << "given indices dtype is " << indices_ty->dtype; @@ -2430,10 +2436,14 @@ Type InferTypeIndexPut(const Call& call, const BlockBuilder& ctx) { if (tensor_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of index tensor " << i << " has not been specified. Assume it has an integer type."; - } else if (!(tensor_ty->dtype.is_int() || tensor_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "IndexPut requires each index tensor to have integer dtype. " - << "However, index tensor " << i << " has dtype=" << tensor_ty->dtype; + } else { + PrimType index_dtype = tensor_ty->dtype; + if (!index_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !index_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "IndexPut requires each index tensor to have integer dtype. 
" + << "However, index tensor " << i << " has dtype=" << tensor_ty->dtype; + } } } @@ -2531,7 +2541,7 @@ Type InferTypeMeshgrid(const Call& call, const BlockBuilder& ctx) { } std::vector lengths; - DataType common_dtype = DataType::Void(); + PrimType common_dtype = PrimType::Void(); bool shape_unknown = false; ffi::Optional vdev = std::nullopt; bool vdevice_unknown = false; @@ -2545,9 +2555,9 @@ Type InferTypeMeshgrid(const Call& call, const BlockBuilder& ctx) { << i; } - if (ty->dtype.is_void()) { + if (ty->IsUnknownDtype()) { continue; - } else if (common_dtype.is_void()) { + } else if (common_dtype.IsVoid()) { common_dtype = ty->dtype; } else if (ty->dtype != common_dtype) { TVM_FFI_VISIT_THROW(TypeError, call) @@ -2683,11 +2693,15 @@ Type InferTypeScatterElements(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "ScatterElements op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "ScatterElements op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_ty->dtype; + } } const auto* indices_shape = indices_ty->shape.as(); @@ -2803,11 +2817,15 @@ Type InferTypeScatterND(const Call& call, const BlockBuilder& ctx) { if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. 
Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "ScatterND op requires the input indices to have integer dtype. However, " - "the given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "ScatterND op requires the input indices to have integer dtype. However, " + "the given indices dtype is " + << indices_ty->dtype; + } } const auto* data_shape = data_ty->shape.as(); @@ -3003,10 +3021,11 @@ Type InferTypeSliceScatter(const Call& call, const BlockBuilder& ctx) { << ") to be a PrimValue, but got " << arg_expr->GetTypeKey(); } const PrimExpr& prim_expr = prim_value_node->value; - if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) { + tvm::PrimType prim_ty = prim_expr.ty(); + if (prim_ty.code() != DLDataTypeCode::kDLInt && prim_ty.code() != DLDataTypeCode::kDLUInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "SliceScatter expects `" << key << "` (" << prim_expr - << ") to be an integer PrimValue, but got dtype " << prim_expr.dtype(); + << ") to be an integer PrimValue, but got dtype " << prim_ty; } return prim_expr; }; @@ -3085,8 +3104,8 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i attrs->axis = axis; // Check if on_value and off_value have the same dtype - DataType on_dtype = on_value->value->dtype; - DataType off_dtype = off_value->value->dtype; + PrimType on_dtype = on_value->value.ty(); + PrimType off_dtype = off_value->value.ty(); TVM_FFI_ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " << "but got " << on_dtype << " and " << off_dtype; @@ -3108,19 +3127,25 @@ Type InferTypeOneHot(const Call& call, const BlockBuilder& ctx) { PrimValue on_value = 
call->args[1].as_or_throw(); PrimValue off_value = call->args[2].as_or_throw(); // Check if on_value and off_value have the same dtype - TVM_FFI_ICHECK(on_value->value->dtype == off_value->value->dtype) + PrimType on_dtype = on_value->value.ty(); + PrimType off_dtype = off_value->value.ty(); + TVM_FFI_ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " - << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; - DataType dtype = on_value->value->dtype; + << "but got " << on_dtype << " and " << off_dtype; + PrimType dtype = on_dtype; // Check if indices has an integer dtype if (indices_ty->IsUnknownDtype()) { LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; - } else if (!(indices_ty->dtype.is_int() || indices_ty->dtype.is_uint())) { - TVM_FFI_VISIT_THROW(TypeError, call) - << "one_hot op requires the input indices to have integer dtype. However, the " - "given indices dtype is " - << indices_ty->dtype; + } else { + PrimType indices_dtype = indices_ty->dtype; + if (!indices_dtype.MatchesCode(DLDataTypeCode::kDLInt) && + !indices_dtype.MatchesCode(DLDataTypeCode::kDLUInt)) { + TVM_FFI_VISIT_THROW(TypeError, call) + << "one_hot op requires the input indices to have integer dtype. 
However, the " + "given indices dtype is " + << indices_ty->dtype; + } } // Check if indices has unknown dimension if (indices_ty->IsUnknownNdim()) { diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index 974d70e7300a..8940594abc51 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { QuantizeAttrs::RegisterReflection(); } /* relax.quantize */ -Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { +Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; @@ -54,9 +54,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); - if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) && - attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) && - attrs->out_dtype != DataType::Float8E4M3FN() && attrs->out_dtype != DataType::Float8E5M2()) { + if (attrs->out_dtype != DLDataType{kDLInt, 8, 1} && + attrs->out_dtype != DLDataType{kDLUInt, 8, 1} && + attrs->out_dtype != DLDataType{kDLInt, 16, 1} && + attrs->out_dtype != DLDataType{kDLUInt, 16, 1} && + attrs->out_dtype != DLDataType{static_cast(kDLFloat8_e4m3fn), + static_cast(8), static_cast(1)} && + attrs->out_dtype != DLDataType{static_cast(kDLFloat8_e5m2), static_cast(8), + static_cast(1)}) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported output datatype attribute for operation: '" << attrs->out_dtype; } @@ -64,24 +69,27 @@ Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetInputTensorType(call, ctx)[0]; TensorType scale_ty = GetInputTensorType(call, ctx)[1]; TensorType zp_ty = GetInputTensorType(call, ctx)[2]; + PrimType input_dtype = input_ty->dtype; + PrimType scale_dtype = scale_ty->dtype; + 
PrimType zp_dtype = zp_ty->dtype; // Check input datatype: - if (input_ty->dtype != DataType::Float(16) && input_ty->dtype != DataType::Float(32)) { + if (input_dtype != PrimType::Float(16) && input_dtype != PrimType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported input datatype for operation: " << input_ty->dtype; } // Check datatype of scale param: - if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { + if (scale_dtype != PrimType::Float(32) && scale_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && - zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && - zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && - zp_ty->dtype != DataType::Float(16)) { + if (zp_dtype != PrimType::Int(8) && zp_dtype != PrimType::UInt(8) && + zp_dtype != PrimType::Int(16) && zp_dtype != PrimType::UInt(16) && + zp_dtype != PrimType::Int(32) && zp_dtype != PrimType::UInt(32) && + zp_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " @@ -124,7 +132,7 @@ Type InferTypeQuantize(const Call& call, const BlockBuilder& ctx) { if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = attrs->out_dtype; + output_ty->dtype = PrimType(attrs->out_dtype); return TensorType(output_ty); } @@ -139,7 +147,7 @@ TVM_REGISTER_OP("relax.quantize") /* relax.dequantize */ -Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype) { +Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType 
out_dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->axis = axis; attrs->out_dtype = out_dtype; @@ -154,7 +162,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); - if (attrs->out_dtype != DataType::Float(16) && attrs->out_dtype != DataType::Float(32)) { + if (attrs->out_dtype != DLDataType{kDLFloat, 16, 1} && + attrs->out_dtype != DLDataType{kDLFloat, 32, 1}) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported output datatype attribute for operation: " << attrs->out_dtype; } @@ -162,28 +171,34 @@ Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { TensorType input_ty = GetInputTensorType(call, ctx)[0]; TensorType scale_ty = GetInputTensorType(call, ctx)[1]; TensorType zp_ty = GetInputTensorType(call, ctx)[2]; + PrimType input_dtype = input_ty->dtype; + PrimType scale_dtype = scale_ty->dtype; + PrimType zp_dtype = zp_ty->dtype; // Check input datatype: - if (input_ty->dtype != DataType::Int(8) && input_ty->dtype != DataType::UInt(8) && - input_ty->dtype != DataType::Int(16) && input_ty->dtype != DataType::UInt(16) && - input_ty->dtype != DataType::Int(32) && input_ty->dtype != DataType::Float8E4M3FN() && - input_ty->dtype != DataType::Float8E5M2() && input_ty->dtype != DataType::Float(16) && - input_ty->dtype != DataType::Float(32)) { + if (input_dtype != PrimType::Int(8) && input_dtype != PrimType::UInt(8) && + input_dtype != PrimType::Int(16) && input_dtype != PrimType::UInt(16) && + input_dtype != PrimType::Int(32) && + input_dtype != PrimType(DLDataType{static_cast(kDLFloat8_e4m3fn), + static_cast(8), static_cast(1)}) && + input_dtype != PrimType(DLDataType{static_cast(kDLFloat8_e5m2), + static_cast(8), static_cast(1)}) && + input_dtype != PrimType::Float(16) && input_dtype != PrimType::Float(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "Unsupported input datatype for operation: " << attrs->out_dtype; } // Check datatype of scale param: - 
if (scale_ty->dtype != DataType::Float(32) && scale_ty->dtype != DataType::Float(16)) { + if (scale_dtype != PrimType::Float(32) && scale_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "scale param datatype should be one of [float16, float32], but got " << scale_ty->dtype; } // Check datatype of zero_point param: - if (zp_ty->dtype != DataType::Int(8) && zp_ty->dtype != DataType::UInt(8) && - zp_ty->dtype != DataType::Int(16) && zp_ty->dtype != DataType::UInt(16) && - zp_ty->dtype != DataType::Int(32) && zp_ty->dtype != DataType::UInt(32) && - zp_ty->dtype != DataType::Float(16)) { + if (zp_dtype != PrimType::Int(8) && zp_dtype != PrimType::UInt(8) && + zp_dtype != PrimType::Int(16) && zp_dtype != PrimType::UInt(16) && + zp_dtype != PrimType::Int(32) && zp_dtype != PrimType::UInt(32) && + zp_dtype != PrimType::Float(16)) { TVM_FFI_VISIT_THROW(TypeError, call) << "zero_point param datatype should be one of " << "['int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'float16'], " @@ -226,7 +241,7 @@ Type InferTypeDequantize(const Call& call, const BlockBuilder& ctx) { if (!is_scalar_or_singleton_vector(zp_ty)) check_param_size(zp_ty, input_ty, "zero_point"); auto output_ty = ffi::make_object(*input_ty.get()); - output_ty->dtype = attrs->out_dtype; + output_ty->dtype = PrimType(attrs->out_dtype); return TensorType(output_ty); } diff --git a/src/relax/op/tensor/qdq.h b/src/relax/op/tensor/qdq.h index 9d13dcde277f..bdb31f87e61e 100644 --- a/src/relax/op/tensor/qdq.h +++ b/src/relax/op/tensor/qdq.h @@ -40,7 +40,7 @@ namespace relax { * \param out_dtype The data type of the output tensor. * \return The computed result. */ -Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype); +Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype); /*! * \brief Dequantize op. 
@@ -53,7 +53,7 @@ Expr quantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dty * \param out_dtype The data type of the output tensor. * \return The computed result. */ -Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DataType out_dtype); +Expr dequantize(Expr data, Expr scale, Expr zero_point, int axis, DLDataType out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index 27f9241e2c29..196e6f887649 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -37,7 +37,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { MultinomialFromUniformAttrs::RegisterReflection(); /* relax.multinomial_from_uniform */ -Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype) { +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, + DLDataType dtype) { ffi::ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -59,19 +60,24 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) TensorType sample_indices_ty = GetInputTensorType(call, 2, ctx); const auto* attrs = call->attrs.as(); - if (!prob_ty->dtype.is_float()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. + if (prob_ty->dtype.code() != DLDataTypeCode::kDLFloat && + prob_ty->dtype.code() != DLDataTypeCode::kDLBfloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input prob to have float dtype. " "However, the given prob dtype is " << prob_ty->dtype; } - if (!uniform_sample_ty->dtype.is_float()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. 
+ if (uniform_sample_ty->dtype.code() != DLDataTypeCode::kDLFloat && + uniform_sample_ty->dtype.code() != DLDataTypeCode::kDLBfloat) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial_from_uniform op requires the input uniform_sample to have float " "dtype. However, the given uniform_sample dtype is " << uniform_sample_ty->dtype; } - if (!sample_indices_ty->dtype.is_int()) { + // Only the element kind matters here; shape inference does not depend on vector lanes. + if (sample_indices_ty->dtype.code() != DLDataTypeCode::kDLInt) { TVM_FFI_VISIT_THROW(TypeError, call) << "Multinomial from uniform op requires the input sample_indices to have int " "dtype. However, the given sample_indices dtype is " @@ -79,7 +85,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) } if (prob_ty->IsUnknownNdim() || uniform_sample_ty->IsUnknownNdim() || sample_indices_ty->IsUnknownNdim()) { - return TensorType(attrs->dtype, kUnknownNDim, prob_ty->vdevice); + return TensorType(PrimType(attrs->dtype), kUnknownNDim, prob_ty->vdevice); } if (prob_ty->ndim != 2) { TVM_FFI_VISIT_THROW(ValueError, call) @@ -109,7 +115,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) // The output shape is expected to be `(n, 1)` if (prob_shape == nullptr || uniform_sample_shape == nullptr || sample_indices_shape == nullptr) { - return TensorType(attrs->dtype, 2, prob_ty->vdevice); + return TensorType(PrimType(attrs->dtype), 2, prob_ty->vdevice); } PrimExpr batch = prob_shape->values[0]; @@ -132,7 +138,7 @@ Type InferTypeMultinomialFromUniform(const Call& call, const BlockBuilder& ctx) << uniform_sample_ty->shape << " and the given sample_indices tensor has shape " << sample_indices_ty->shape; } - return TensorType(ShapeExpr({n, 1}), attrs->dtype, prob_ty->vdevice); + return TensorType(ShapeExpr({n, 1}), PrimType(attrs->dtype), prob_ty->vdevice); } TVM_REGISTER_OP("relax.multinomial_from_uniform") diff --git a/src/relax/op/tensor/sampling.h 
b/src/relax/op/tensor/sampling.h index d13aa835d68d..077ef4313669 100644 --- a/src/relax/op/tensor/sampling.h +++ b/src/relax/op/tensor/sampling.h @@ -49,7 +49,8 @@ namespace relax { * \param dtype The data type of the output tensor. * \return The sampled result. */ -Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, DataType dtype); +Expr multinomial_from_uniform(Expr prob, Expr uniform_sample, Expr sample_indices, + DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index d80f484ebcf5..635879db2be3 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -64,10 +64,9 @@ Type InferTypeBucketize(const Call& call, const BlockBuilder& ctx) { } auto attrs = call->attrs.as(); - DataType out_dtype; - out_dtype = DataType::Int(64); + PrimType out_dtype = PrimType::Int(64); if (attrs->out_int32) { - out_dtype = DataType::Int(32); + out_dtype = PrimType::Int(32); } const auto* data_shape = input_tensor_info->shape.as(); @@ -119,13 +118,15 @@ Type InferTypeWhere(const Call& call, const BlockBuilder& ctx) { } } - if (!cond_ty->dtype.is_bool()) { + PrimType cond_dtype = cond_ty->dtype; + // Where condition validation only checks the boolean element kind; lanes are irrelevant here. + if (cond_dtype.code() != DLDataTypeCode::kDLBool) { TVM_FFI_VISIT_THROW(TypeError, call) << "Where requires the input condition tensor to have boolean dtype. 
However, " "the given condition dtype is " << cond_ty->dtype; } - DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty); + PrimType output_dtype(InferBinaryArithOpOutDtype(call, ctx, x1_ty, x2_ty)); int output_ndim; if (cond_ty->IsUnknownNdim() || x1_ty->IsUnknownNdim() || x2_ty->IsUnknownNdim()) { @@ -209,7 +210,7 @@ Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { TVM_FFI_ICHECK_GE(out_ndim, 0); } - DataType out_dtype = DataType::Int(64); + PrimType out_dtype = PrimType::Int(64); // The inference rule for reduction operator output shapes: // - axes is None, keepdims is false -> return the zero-rank shape; // - axes is None, keepdims is true -> return the shape whose ndim is the same as input and every @@ -230,7 +231,7 @@ Type InferTypeArgmaxArgmin(const Call& call, const BlockBuilder& ctx) { } if (data_ty->ndim > 0) { - out_dtype = data_shape->values[0]->dtype; + out_dtype = data_shape->values[0].ty(); } ffi::Array out_shape; diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 57999a3356b7..a92cbee4a001 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -106,9 +106,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_index->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -116,9 +116,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_inverse->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + 
TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -126,9 +126,9 @@ Type InferTypeUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_counts->value)) { if (data_ty->ndim == 0) { output_ty.push_back( - TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), data_ty->vdevice)); + TensorType(ShapeExpr({IntImm::Int64(/*value=*/1)}), PrimType::Int(64), data_ty->vdevice)); } else { - output_ty.push_back(TensorType(DataType::Int(64), /*ndim=*/1, data_ty->vdevice)); + output_ty.push_back(TensorType(PrimType::Int(64), /*ndim=*/1, data_ty->vdevice)); } } @@ -175,7 +175,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeNonzero(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetInputTensorType(call, 0, ctx); - return TensorType(DataType::Int(64), 2, data_ty->vdevice); + return TensorType(PrimType::Int(64), 2, data_ty->vdevice); } TVM_REGISTER_OP("relax.nonzero") diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index 2d014cded4ec..c470fa0d4f6e 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -66,7 +66,7 @@ TVM_REGISTER_OP("relax.sort") /* relax.argsort */ -Expr argsort(Expr data, int axis, bool descending, DataType dtype) { +Expr argsort(Expr data, int axis, bool descending, DLDataType dtype) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); attrs->descending = std::move(descending); @@ -84,7 +84,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { Type InferTypeArgsort(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? 
data_ty->dtype : attrs->dtype; + PrimType out_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); if (data_ty->shape.defined()) { return TensorType(data_ty->shape.value(), out_type, data_ty->vdevice); } @@ -100,7 +101,7 @@ TVM_REGISTER_OP("relax.argsort") /* relax.topk */ -Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype) { +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DLDataType dtype) { auto attrs = ffi::make_object(); attrs->k = std::move(k); attrs->axis = std::move(axis); @@ -121,7 +122,8 @@ Type InferTypeTopK(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* data_shape = data_ty->shape.as(); const auto* attrs = call->attrs.as(); - DataType indices_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + PrimType indices_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); int ndim = data_ty->ndim; int k = attrs->k; ffi::String ret_type = attrs->ret_type; diff --git a/src/relax/op/tensor/sorting.h b/src/relax/op/tensor/sorting.h index a4154ce416ad..8a2ec98388df 100644 --- a/src/relax/op/tensor/sorting.h +++ b/src/relax/op/tensor/sorting.h @@ -51,7 +51,7 @@ Expr sort(Expr data, int axis, bool descending); * \param dtype The data type of the output indices. * \return The computed result. */ -Expr argsort(Expr data, int axis, bool descending, DataType dtype); +Expr argsort(Expr data, int axis, bool descending, DLDataType dtype); /*! * \brief Get the top k elements in an input tensor along the given axis. @@ -63,7 +63,7 @@ Expr argsort(Expr data, int axis, bool descending, DataType dtype); * \param dtype The data type of the indices output. * \return The computed result. 
*/ -Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DataType dtype); +Expr topk(Expr data, int k, int axis, ffi::String ret_type, bool largest, DLDataType dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 9fe68afe2901..15bbd701e67f 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -155,7 +155,8 @@ Type InferTypeScan(const Call& call, const BlockBuilder& ctx) { TensorType data_ty = GetUnaryInputTensorType(call, ctx); const auto* attrs = call->attrs.as(); - DataType out_type = attrs->dtype.is_void() ? data_ty->dtype : attrs->dtype; + PrimType out_type = + attrs->dtype == DLDataType{kDLOpaqueHandle, 0, 0} ? data_ty->dtype : PrimType(attrs->dtype); if (!attrs->axis.has_value()) { // flattened @@ -216,7 +217,7 @@ Type InferTypeStatisticalExtension(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(ffi::Array()), data_ty->dtype, data_ty->vdevice); } return TupleType({TensorType(data_ty->dtype, out_ndim, data_ty->vdevice), - TensorType(DataType::Int(64), out_ndim, data_ty->vdevice)}); + TensorType(PrimType::Int(64), out_ndim, data_ty->vdevice)}); } ffi::Array out_shape; @@ -234,15 +235,15 @@ Type InferTypeStatisticalExtension(const Call& call, const BlockBuilder& ctx) { return TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice); else return TupleType({TensorType(ShapeExpr(out_shape), data_ty->dtype, data_ty->vdevice), - TensorType(ShapeExpr(out_shape), DataType::Int(64), data_ty->vdevice)}); + TensorType(ShapeExpr(out_shape), PrimType::Int(64), data_ty->vdevice)}); } /* relax.cumprod */ -Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, +Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype, bool exclusive) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); - attrs->dtype = std::move(dtype.value_or(DataType::Void())); + attrs->dtype = 
dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumprod"); @@ -262,10 +263,11 @@ TVM_REGISTER_OP("relax.cumprod") .set_attr("FPurity", true); /* relax.cumsum */ -Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, bool exclusive) { +Expr cumsum(Expr data, ffi::Optional axis, ffi::Optional dtype, + bool exclusive) { auto attrs = ffi::make_object(); attrs->axis = std::move(axis); - attrs->dtype = std::move(dtype.value_or(DataType::Void())); + attrs->dtype = dtype.value_or((DLDataType{kDLOpaqueHandle, 0, 0})); attrs->exclusive = exclusive; static const Op& op = Op::Get("relax.cumsum"); diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h index 2d80790926ed..3ab998110603 100644 --- a/src/relax/op/tensor/statistical.h +++ b/src/relax/op/tensor/statistical.h @@ -99,7 +99,7 @@ Expr sum(Expr x, ffi::Optional> axis, bool keepdims); * result. */ Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, - ffi::Optional dtype = std::nullopt, bool exclusive = false); + ffi::Optional dtype = std::nullopt, bool exclusive = false); /*! * \brief Numpy style cumsum op. Return the cumulative inclusive sum of the elements along @@ -114,7 +114,7 @@ Expr cumprod(Expr data, ffi::Optional axis = std::nullopt, * \return The computed result. */ Expr cumsum(Expr data, ffi::Optional axis = std::nullopt, - ffi::Optional dtype = std::nullopt, bool exclusive = false); + ffi::Optional dtype = std::nullopt, bool exclusive = false); /*! \brief Computes the variance of tensor elements over given axes. 
*/ Expr variance(Expr x, ffi::Optional> axis, bool keepdims); diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 6daacfe16578..1e21e7dbdcc7 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -57,9 +57,9 @@ Type InferTypeEwiseFMA(const Call& call, const BlockBuilder& ctx) { } } - DataType output_dtype; + PrimType output_dtype = PrimType::Void(); if (t1->IsUnknownDtype() || t2->IsUnknownDtype() || t3->IsUnknownDtype()) { - output_dtype = DataType::Void(); + output_dtype = PrimType::Void(); } else if (t1->dtype != t2->dtype || t2->dtype != t3->dtype) { TVM_FFI_VISIT_THROW(TypeError, call) << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype << " must be equal for EwiseFMA"; diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 598ec78aacda..bd15223df878 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -33,7 +33,7 @@ namespace relax { Type InferTypeUnaryCheck(const Call& call, const BlockBuilder& ctx) { return InferTypeUnary(call, ctx, - [](const TensorType& input_ty) { return DataType::Bool(); }); + [](const TensorType& input_ty) { return PrimType::Bool(); }); } /***************** Arithmetic operators *****************/ diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index bde579f0ed5a..6f289d6b8755 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -84,8 +84,8 @@ Type InferTypeAllClassNMS(const Call& call, const BlockBuilder& ctx) { ShapeExpr oshape(oshape_values); tvm::ffi::Array counts_values = {1}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorType(oshape, DataType::Int(64), vdev), - TensorType(counts_shape, DataType::Int(64), vdev)}; + tvm::ffi::Array fields = {TensorType(oshape, PrimType::Int(64), vdev), + TensorType(counts_shape, PrimType::Int(64), vdev)}; return TupleType(fields); } @@ -96,9 +96,9 @@ Type InferTypeAllClassNMS(const Call& call, 
const BlockBuilder& ctx) { ShapeExpr scores_shape(scores_values); tvm::ffi::Array counts_values = {batch}; ShapeExpr counts_shape(counts_values); - tvm::ffi::Array fields = {TensorType(indices_shape, DataType::Int(64), vdev), - TensorType(scores_shape, DataType::Float(32), vdev), - TensorType(counts_shape, DataType::Int(64), vdev)}; + tvm::ffi::Array fields = {TensorType(indices_shape, PrimType::Int(64), vdev), + TensorType(scores_shape, PrimType::Float(32), vdev), + TensorType(counts_shape, PrimType::Int(64), vdev)}; return TupleType(fields); } @@ -153,9 +153,9 @@ Type InferTypeGetValidCounts(const Call& call, const BlockBuilder& ctx) { auto vdev = data_ty->vdevice; const auto* data_shape = data_ty->shape.as(); if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/1, vdev), + tvm::ffi::Array fields = {TensorType(PrimType::Int(32), /*ndim=*/1, vdev), TensorType(data_ty->dtype, /*ndim=*/3, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } @@ -177,9 +177,9 @@ Type InferTypeGetValidCounts(const Call& call, const BlockBuilder& ctx) { } tvm::ffi::Array fields = { - TensorType(ShapeExpr({batch}), DataType::Int(32), vdev), + TensorType(ShapeExpr({batch}), PrimType::Int(32), vdev), TensorType(ShapeExpr({batch, num_anchors, elem_length}), data_ty->dtype, vdev), - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)}; + TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev)}; return TupleType(fields); } @@ -251,12 +251,12 @@ Type InferTypeNMS(const Call& call, const BlockBuilder& ctx) { TVM_FFI_VISIT_THROW(ValueError, call) << "non_max_suppression expects indices to be 2-D, got ndim " << indices_ty->ndim; } - if (!valid_count_ty->IsUnknownDtype() && valid_count_ty->dtype != DataType::Int(32)) { + if (!valid_count_ty->IsUnknownDtype() && valid_count_ty->dtype != PrimType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, 
call) << "non_max_suppression expects valid_count to have dtype int32, got " << valid_count_ty->dtype; } - if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != DataType::Int(32)) { + if (!indices_ty->IsUnknownDtype() && indices_ty->dtype != PrimType::Int(32)) { TVM_FFI_VISIT_THROW(TypeError, call) << "non_max_suppression expects indices to have dtype int32, got " << indices_ty->dtype; } @@ -319,30 +319,30 @@ Type InferTypeNMS(const Call& call, const BlockBuilder& ctx) { // valid_box_count[batch, 1]) if (data_shape == nullptr) { tvm::ffi::Array fields = {TensorType(data_ty->dtype, /*ndim=*/3, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + TensorType(PrimType::Int(32), /*ndim=*/2, vdev), + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; tvm::ffi::Array fields = { TensorType(ffi::GetRef(data_shape), data_ty->dtype, vdev), - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), PrimType::Int(32), vdev)}; return TupleType(fields); } // Hard NMS returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) if (data_shape == nullptr) { - tvm::ffi::Array fields = {TensorType(DataType::Int(32), /*ndim=*/2, vdev), - TensorType(DataType::Int(32), /*ndim=*/2, vdev)}; + tvm::ffi::Array fields = {TensorType(PrimType::Int(32), /*ndim=*/2, vdev), + TensorType(PrimType::Int(32), /*ndim=*/2, vdev)}; return TupleType(fields); } auto batch = data_shape->values[0]; auto num_anchors = data_shape->values[1]; tvm::ffi::Array fields = { - TensorType(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorType(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; + 
TensorType(ShapeExpr({batch, num_anchors}), PrimType::Int(32), vdev), + TensorType(ShapeExpr({batch, IntImm::Int64(1)}), PrimType::Int(32), vdev)}; return TupleType(fields); } diff --git a/src/relax/script/printer/dependent_type.cc b/src/relax/script/printer/dependent_type.cc index a37c21406fac..e3a14c0cdafe 100644 --- a/src/relax/script/printer/dependent_type.cc +++ b/src/relax/script/printer/dependent_type.cc @@ -100,7 +100,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (!n->IsUnknownDtype()) { kwargs_keys.push_back("dtype"); - kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))); + kwargs_values.push_back(LiteralDoc::DataType(n->dtype->dtype, n_p->Attr("dtype"))); } if (!n->shape.defined() && !n->IsUnknownNdim()) { kwargs_keys.push_back("ndim"); diff --git a/src/relax/script/printer/distributed.cc b/src/relax/script/printer/distributed.cc index f05ec8fe714a..97d800d5d139 100644 --- a/src/relax/script/printer/distributed.cc +++ b/src/relax/script/printer/distributed.cc @@ -61,11 +61,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } if (!n->tensor_ty->IsUnknownDtype()) { if (!require_kwargs) { - args.push_back(LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); + args.push_back(LiteralDoc::DataType(n->tensor_ty->dtype->dtype, n_p->Attr("dtype"))); } else { kwargs_keys.push_back("dtype"); kwargs_values.push_back( - LiteralDoc::DataType(n->tensor_ty->dtype, n_p->Attr("dtype"))); + LiteralDoc::DataType(n->tensor_ty->dtype->dtype, n_p->Attr("dtype"))); } } else { require_kwargs = true; diff --git a/src/relax/script/printer/expr.cc b/src/relax/script/printer/expr.cc index dfce2b40b1f9..7b2f39ecf335 100644 --- a/src/relax/script/printer/expr.cc +++ b/src/relax/script/printer/expr.cc @@ -81,21 +81,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& p) { - DataType dtype = n.DataType(); + DLDataType dtype = n.DataType(); const void* data = n->data; if 
(n->ndim != 0 || n->device.device_type != kDLCPU) { return std::nullopt; } - if (dtype == DataType::Int(8)) { + if (dtype == DLDataType{kDLInt, 8, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { return LiteralDoc::Int(*reinterpret_cast(data), p); - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { // From IEEE-754 float16 definition // // Ref: https://en.wikipedia.org/wiki/Half-precision_floating-point_format @@ -122,11 +122,11 @@ ffi::Optional SpecialScalar(const runtime::Tensor& n, const AccessPath& } return LiteralDoc::Float(value, p); - } else if (dtype == DataType::Float(32)) { + } else if (dtype == DLDataType{kDLFloat, 32, 1}) { return LiteralDoc::Float(*reinterpret_cast(data), p); - } else if (dtype == DataType::Float(64)) { + } else if (dtype == DLDataType{kDLFloat, 64, 1}) { return LiteralDoc::Float(*reinterpret_cast(data), p); - } else if (dtype == DataType::Bool()) { + } else if (dtype == DLDataType{kDLBool, 8, 1}) { return LiteralDoc::Boolean(*reinterpret_cast(data), p); } else { return std::nullopt; diff --git a/src/relax/script/printer/tir.cc b/src/relax/script/printer/tir.cc index e0742f8edd44..06bce7c1ff8c 100644 --- a/src/relax/script/printer/tir.cc +++ b/src/relax/script/printer/tir.cc @@ -43,9 +43,10 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { } Doc PrintTIRVar(tirx::Var n, AccessPath n_p, IRDocsifier d) { - TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) + PrimType n_ty = n.ty(); + TVM_FFI_CHECK(!n_ty.IsScalableVector() && !n_ty.IsFixedLengthVector(), TypeError) << "Relax only uses scalar TIR variables," - 
<< "but received TIR variable " << n << " with dtype " << n->dtype; + << "but received TIR variable " << n << " with dtype " << n_ty->dtype; if (!d->IsVarDefined(n)) { RelaxFrameNode* f = GetRelaxFrame(d); @@ -77,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, AccessPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - if (n->dtype.is_bool()) { + if (n->ty().MatchesElementType(DLDataTypeCode::kDLBool, 8)) { return LiteralDoc::Boolean(n->value, n_p); } else { return LiteralDoc::Int(n->value, n_p); diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 4cf8831514dc..2d6e6fcc5e33 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -208,22 +208,24 @@ std::tuple)>> // If two of the three are compile-time, group those two values // together, to allow them to be lifted out and pre-computed. if (is_compile_time(expr_a) && is_compile_time(expr_b)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + return matmul(matmul(expr_a, expr_b, (DLDataType{kDLOpaqueHandle, 0, 0})), expr_c, + (DLDataType{kDLOpaqueHandle, 0, 0})); } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + return matmul(expr_a, matmul(expr_b, expr_c, (DLDataType{kDLOpaqueHandle, 0, 0})), + (DLDataType{kDLOpaqueHandle, 0, 0})); } // Otherwise, select the order that reduces the total number of // operations required, assuming a naive matmul (see below). 
if (shape_a.size() == 1) { - shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]}; + shape_a = {IntImm(shape_a[0].ty(), 1), shape_a[0]}; } if (shape_b.size() == 1) { if (matches.count(pat_matmul_on_lhs)) { - shape_b = {shape_b[0], IntImm(shape_b[0].dtype(), 1)}; + shape_b = {shape_b[0], IntImm(shape_b[0].ty(), 1)}; } else if (matches.count(pat_matmul_on_rhs)) { - shape_b = {IntImm(shape_b[0].dtype(), 1), shape_b[0]}; + shape_b = {IntImm(shape_b[0].ty(), 1), shape_b[0]}; } else { TVM_FFI_THROW(InternalError) << "OrPattern " << pat << " matched, but neither " << pat_matmul_on_lhs << " nor " @@ -231,7 +233,7 @@ std::tuple)>> } } if (shape_c.size() == 1) { - shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)}; + shape_c = {shape_c[0], IntImm(shape_c[0].ty(), 1)}; } PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A @@ -285,9 +287,11 @@ std::tuple)>> size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); if (analyzer->CanProve(ops_with_lhs_first < ops_with_rhs_first)) { - return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); + return matmul(matmul(expr_a, expr_b, (DLDataType{kDLOpaqueHandle, 0, 0})), expr_c, + (DLDataType{kDLOpaqueHandle, 0, 0})); } else if (analyzer->CanProve(ops_with_rhs_first < ops_with_lhs_first)) { - return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); + return matmul(expr_a, matmul(expr_b, expr_c, (DLDataType{kDLOpaqueHandle, 0, 0})), + (DLDataType{kDLOpaqueHandle, 0, 0})); } // If we cannot determine which order is best, keep the existing order. diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 4dfc84b822da..a593cb7ffee7 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,7 +61,7 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. 
ffi::Array new_params = func_node->params; - auto ty = TensorType(ShapeExpr({IntImm::Int32(max_workspace_size_)}), DataType::UInt(8)); + auto ty = TensorType(ShapeExpr({IntImm::Int32(max_workspace_size_)}), PrimType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), ty); if (func_node->GetAttr(attr::kCodegen)) { @@ -149,7 +149,7 @@ class WorkspaceProvider : ExprMutator { builder_->BeginDataflowBlock(); if (!workspace_var_main_.defined()) { auto shape = ShapeExpr({IntImm::Int32(max_workspace_size_)}); - auto ty = DataTypeImm(DataType::UInt(8)); + auto ty = DataTypeImm((DLDataType{kDLUInt, 8, 1})); auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0)); workspace_var_main_ = builder_->Emit(workspace, "workspace_main"); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index a938b946d20c..7a3b5743f423 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -45,7 +45,7 @@ static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { - return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.dtype(), 0), dim); }); + return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.ty(), 0), dim); }); } static ffi::Array GetShapeFromTensorType(const TensorType& tensor_ty) { @@ -206,7 +206,7 @@ class AlterOpImplMutator : public ExprMutator { * \brief Adds the \p remove_pad op to the module if it has not already been added before. * \returns The global var associated with the remove_pad PrimFunc. 
*/ - GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, const DataType& dtype) { + GlobalVar GetOrCreateRemovePadOp(const ffi::Array& old_shape, DLDataType dtype) { int t_shape = old_shape.size(); if (remove_pad_map_.count(t_shape) != 0) { return remove_pad_map_[t_shape]; @@ -214,8 +214,8 @@ class AlterOpImplMutator : public ExprMutator { // Create dynamic shapes for input and output tensors ffi::Array dyn_padded_shape, dyn_old_shape; for (int i = 0; i < t_shape; i++) { - tirx::Var var1("p" + std::to_string(i), old_shape[i].dtype()); - tirx::Var var2("i" + std::to_string(i), old_shape[i].dtype()); + tirx::Var var1("p" + std::to_string(i), old_shape[i].ty()); + tirx::Var var2("i" + std::to_string(i), old_shape[i].ty()); dyn_padded_shape.push_back(var1); dyn_old_shape.push_back(var2); } @@ -264,7 +264,7 @@ class AlterOpImplMutator : public ExprMutator { TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator)); const auto& tensor_ty = padded_expr->ty.as_or_throw(); - GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype); + GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_ty->dtype->dtype); return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_ty}); } } diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 61fee5be7f8d..5a1bbcaa0040 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -90,12 +90,12 @@ class CallTIRMutator : public ExprMutator { } if (!is_inplace) { - outs.push_back(builder_->Emit( - Call(alloc_tensor_op, - {output_ty->shape.value().as_or_throw(), - DataTypeImm(output_ty->dtype), PrimValue::Int64(dev_index), StringImm(scope)}, - Attrs(), {output_ty}), - "alloc")); + outs.push_back(builder_->Emit(Call(alloc_tensor_op, + {output_ty->shape.value().as_or_throw(), + DataTypeImm(output_ty->dtype->dtype), + PrimValue::Int64(dev_index), StringImm(scope)}, + Attrs(), 
{output_ty}), + "alloc")); } else { // if there is only one output, it must be an in-place argument, but check anyway TVM_FFI_ICHECK(inplace_attrs->inplace_indices[0] != -1) @@ -129,8 +129,8 @@ class CallTIRMutator : public ExprMutator { outs.push_back( builder_->Emit(Call(alloc_tensor_op, {field_tensor->shape.value().as_or_throw(), - DataTypeImm(field_tensor->dtype), PrimValue::Int64(dev_index), - StringImm(scope)}, + DataTypeImm(field_tensor->dtype->dtype), + PrimValue::Int64(dev_index), StringImm(scope)}, Attrs(), {field_tensor}), "alloc")); } else { diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 1319356ee169..128202063695 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -202,7 +202,7 @@ ffi::TypedFunction(ffi::Map, ffi::Mapdtype; + DLDataType out_dtype = GetTensorType(matchings[patterns.matmul[indices[0]]])->dtype->dtype; auto matmul_combined = matmul(lhs, concat_rhs, out_dtype); if (branch_info.bias_dim) { diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 4ad34d04367d..4c937fe135dc 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -43,11 +43,12 @@ class PrimValueComputeInjector : public ExprMutator { return node; } - auto ret_dtype = node->value->dtype; + tvm::PrimType ret_ty = node->value.ty(); auto param_vars = tirx::UndefinedVars(node->value); - tirx::Stmt body = tirx::Evaluate(tirx::Call(ret_dtype, tirx::builtin::ret(), {node->value})); + tirx::Stmt body = + tirx::Evaluate(tirx::Call(node->value.ty(), tirx::builtin::ret(), {node->value})); - tirx::PrimFunc func(param_vars, body, tvm::PrimType(ret_dtype), {}, + tirx::PrimFunc func(param_vars, body, ret_ty, {}, DictAttrs({{tirx::attr::kIsHostFunc, true}, {tvm::attr::kSTir, true}})); func = s_tir::RenewDefs(func); diff --git a/src/relax/transform/convert_layout.cc 
b/src/relax/transform/convert_layout.cc index ed2a9b1c8a8a..bd4631bb4cf8 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -102,7 +102,7 @@ class LayoutConvertMutator : public ExprMutator { ffi::Array initial_indices_expr; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { - auto var = tvm::tirx::Var("i" + std::to_string(i), DataType::Int(32)); + auto var = tvm::tirx::Var("i" + std::to_string(i), PrimType::Int(32)); initial_indices.push_back(var); initial_indices_expr.push_back(var); } diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index fcedd3119599..289c1c3c3b40 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -383,7 +383,7 @@ std::unordered_set GatherCandidat const Type& result_ty) { if (auto* tensor_info = result_ty.as()) { // don't consider void dtype (don't know the size at compile time) - if (tensor_info->dtype.is_void()) { + if (tensor_info->dtype.IsVoid()) { return {}; } // don't consider cases where we don't know the shape at compile time diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 494e4a67a4a4..156d3c278c46 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -66,7 +66,7 @@ Tuple DecomposeBatchNorm(const Call& call) { Expr moving_var = ExpandToMatchInput(call->args[4], ty->ndim, {attrs->axis}); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype->dtype); Expr sqrt_var = sqrt(add(moving_var, epsilon)); Expr out = divide(subtract(data, moving_mean), sqrt_var); @@ -103,8 +103,8 @@ Expr MutateBatchNormForTraining(Call call) { Expr data_mean = mean(data, reduce_axes, false); Expr data_var = variance(data, reduce_axes, false); - Expr momentum = 
MakeConstantScalar(attrs->momentum, ty->dtype); - Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype); + Expr momentum = MakeConstantScalar(attrs->momentum, ty->dtype->dtype); + Expr one_minus_mom = MakeConstantScalar(1 - attrs->momentum, ty->dtype->dtype); Expr new_moving_mean = add(multiply(one_minus_mom, moving_mean), multiply(momentum, data_mean)); Expr new_moving_var = add(multiply(one_minus_mom, moving_var), multiply(momentum, data_var)); @@ -128,7 +128,7 @@ Expr DecomposeLayerNorm(const Call& call) { Expr data_var = variance(data, attrs->axes, true); // output = (x - mean) / sqrt(var + epsilon) * gamma + beta - Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype); + Expr epsilon = MakeConstantScalar(attrs->epsilon, ty->dtype->dtype); Expr sqrt_var = sqrt(add(data_var, epsilon)); Expr out = divide(subtract(data, data_mean), sqrt_var); @@ -159,7 +159,7 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. 
ffi::Array shape_var; for (int i = 0; i < ty->ndim; i++) { - shape_var.push_back(tirx::Var("x", DataType::Int(64))); + shape_var.push_back(tirx::Var("x", PrimType::Int(64))); } // bind symbolic variables to the shape tuple relax::Var var("y", ShapeType(shape_var)); diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 1e768478fd95..9bf5fbd53b2d 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -88,7 +88,8 @@ std::tuple)>> rhs_b = permute_dims(rhs_b, axes); } - return add(matmul(lhs, rhs_a, DataType::Void()), matmul(lhs, rhs_b, DataType::Void())); + return add(matmul(lhs, rhs_a, (DLDataType{kDLOpaqueHandle, 0, 0})), + matmul(lhs, rhs_b, (DLDataType{kDLOpaqueHandle, 0, 0}))); }; return {pat_matmul, rewriter}; diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index d615c014709b..7c92ae49c578 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -197,7 +197,7 @@ class ConstantFolder : public ExprMutator { // Returns std::nullopt on failure. ffi::Optional ConstEvaluateCallTIR(tirx::PrimFunc tir_func, ffi::Array arr_args, ffi::Shape shape, - DataType ret_type) { + DLDataType ret_type) { // obtain function from the cache. ffi::Optional func = GetCachedBuild(tir_func); if (!func) return std::nullopt; @@ -243,7 +243,8 @@ class ConstantFolder : public ExprMutator { if (!shape) return std::nullopt; auto tensor_ty = tuple_ty->fields[i].as_or_throw(); if (tensor_ty->IsUnknownDtype()) return std::nullopt; - ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_ty->dtype, cpu_dev)); + ret_tensors.push_back( + runtime::Tensor::Empty(shape.value(), tensor_ty->dtype->dtype, cpu_dev)); } // Pack input args + all output tensors. 
@@ -288,7 +289,8 @@ class ConstantFolder : public ExprMutator { ffi::Optional shape = MatchConstShape(call->ty_args[0]); if (shape) { TensorType ret_ty = call->ty.as_or_throw(); - return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_ty->dtype) + return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), + ret_ty->dtype->dtype) .value_or({}); } return {}; @@ -391,7 +393,7 @@ class ConstantFolder : public ExprMutator { for (size_t i = 0; i < values.size(); i++) { PrimExpr val = values[i]; arr.push_back(val.as()->value); - is_known &= (val.dtype() == DataType::Int(64)); + is_known &= val.ty().MatchesElementType(DLDataTypeCode::kDLInt, 64); } if (is_known) { const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor"); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index d5e656d15256..00c1029a98d1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -60,10 +60,10 @@ class SymbolicMatcher : ExprFunctordtype + << " cannot match to argument " << other << " with dtype " << other.ty()->dtype; } else { ExprFunctor::VisitExpr(node, other); } @@ -120,9 +120,10 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs) { - TVM_FFI_THROW(InternalError) << "Parameter expression " << ffi::GetRef(op) - << " expected an cast to " << op->dtype << " as the argument, " - << "but was provided with the argument " << other; + TVM_FFI_THROW(InternalError) + << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " + << op->ty()->dtype << " as the argument, " + << "but was provided with the argument " << other; } VisitExpr(op->value, rhs->value); } @@ -132,10 +133,11 @@ class SymbolicMatcher : ExprFunctordtype.code() != rhs->dtype.code()) { + } else if (op->ty().code() != rhs.ty().code()) { TVM_FFI_THROW(InternalError) - << "Parameter expression " << ffi::GetRef(op) << " with dtype " << op->dtype - << " cannot match to argument " << rhs << " with dtype " 
<< rhs.dtype(); + << "Parameter expression " << ffi::GetRef(op) << " with dtype " + << op->ty()->dtype << " cannot match to argument " << rhs << " with dtype " + << rhs.ty()->dtype; } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { VisitExpr((*it).second, rhs); } else { @@ -592,7 +594,7 @@ class FusedTIRConstructor : public ExprVisitor { // printed, it's more readable when done explicitly. Since // Buffer is used more than param it gets the name with better // readability. - tirx::Var param = tirx::Var("p_" + buffer->name, tvm::PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_" + buffer->name, tvm::PrimType::Handle()); func_info_.params.push_back(param); func_info_.buffer_map.Set(param, buffer); } @@ -636,8 +638,7 @@ class FusedTIRConstructor : public ExprVisitor { continue; } - tirx::Var param = - tirx::Var("p_output" + std::to_string(out_idx), tvm::PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_output" + std::to_string(out_idx), tvm::PrimType::Handle()); out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); @@ -855,9 +856,10 @@ class FusedTIRConstructor : public ExprVisitor { for (int64_t idx : output_indices) { int i = static_cast(idx); const tirx::Var& param = func->params[static_cast(i)]; - if (param->dtype.is_int() || param->dtype.is_uint()) { + tvm::PrimType param_ty = param.ty(); + if (param_ty.code() == DLDataTypeCode::kDLInt || param_ty.code() == DLDataTypeCode::kDLUInt) { if (symbolic_var_index == -1) symbolic_var_index = i; - } else if (param->dtype.is_handle()) { + } else if (param_ty.IsHandle()) { TVM_FFI_ICHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the " "parameter list."; @@ -865,7 +867,7 @@ class FusedTIRConstructor : public ExprVisitor { } else { TVM_FFI_THROW(InternalError) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " - << param->dtype; + << param_ty->dtype; } } @@ -967,7 +969,7 @@ 
class FusedTIRConstructor : public ExprVisitor { // Case 1. The relax param is a Tensor, we directly create a tirx var and buffer const auto* shape_expr = tensor->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; - DataType dtype = tensor->dtype; + DLDataType dtype = tensor->dtype->dtype; tirx::Buffer buffer; if (tir_buffer_param.defined()) { buffer = tirx::decl_buffer(shape_expr->values, dtype, name_hint, @@ -980,7 +982,7 @@ class FusedTIRConstructor : public ExprVisitor { } else if (const auto* prim_value = ty.as()) { // Case 2. The relax param is a scalar, we directly create a tirx var - out->push_back(tirx::Var(name_hint, prim_value->dtype)); + out->push_back(tirx::Var(name_hint, tvm::PrimType(prim_value->dtype))); } else if (const auto* shape_expr = ty.as()) { // Case 3. The relax param is a tuple of scalars, each represented as a tirx var @@ -1257,7 +1259,7 @@ class TIRFuseMutator : public ExprMutator { if (const auto* literal = arg.as()) { tir_vars.push_back(literal->value); } else if (const auto* var = arg.as()) { - tir_vars.push_back(tirx::Var(var->name_hint(), prim_value->dtype)); + tir_vars.push_back(tirx::Var(var->name_hint(), tvm::PrimType(prim_value->dtype))); } else { TVM_FFI_THROW(TypeError) << "FuseTIR expects scalar arguments to be PrimValue or Var, " << "but received " << arg; diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index df22650e036d..e23524388435 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -304,7 +304,7 @@ class BackwardBindingGenerator : private ExprVisitor { // Initialize the adjoint of target_var as ones op. We have already checked the target. 
auto* target_ty = GetTypeAs(target_var); - generator.UpdateAdjoint(target_var, ones(target_ty->shape.value(), target_ty->dtype)); + generator.UpdateAdjoint(target_var, ones(target_ty->shape.value(), target_ty->dtype->dtype)); // Do reverse-mode ad, so visit bindings backwards for (auto it = forward_block->bindings.rbegin(); it != forward_block->bindings.rend(); ++it) { @@ -546,7 +546,7 @@ class BackwardBindingGenerator : private ExprVisitor { auto* tensor_ty = ty.as(); TVM_FFI_ICHECK(tensor_ty) << "The leaf of adjoint should be a Tensor."; TVM_FFI_ICHECK(tensor_ty->shape.defined()) << "Missing shape when building zeros tuple."; - const Expr& init = zeros(tensor_ty->shape.value(), tensor_ty->dtype); + const Expr& init = zeros(tensor_ty->shape.value(), tensor_ty->dtype->dtype); return init; }); return AdjointMsgToExpr(msg); @@ -707,7 +707,8 @@ class GradientMutator : private ExprMutator { static bool IsFloatTensorType(const Type& ty) { auto* tensor_ty = ty.as(); - return tensor_ty && tensor_ty->dtype.is_float(); + // Gradient eligibility preserves the old float-kind check; lanes do not affect this policy. 
+ return tensor_ty && tensor_ty->dtype.code() == DLDataTypeCode::kDLFloat; } // When the return value is a Var, it is the target; diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 41c6cfe5ae42..4952aeea8fa2 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -22,19 +22,19 @@ namespace tvm { namespace relax { -NType NTypeFrom(const Type& ty, DataType dtype) { +NType NTypeFrom(const Type& ty, DLDataType dtype) { auto fmapleaf = [&](const Type& ty) -> NType { const auto* tensor = ty.as(); TVM_FFI_ICHECK(tensor) << "Expected TensorType, but got " << ty; - if (dtype == DataType::Void()) - return NType(DLDataTypeToString(tensor->dtype)); + if (dtype == DLDataType{kDLOpaqueHandle, 0, 0}) + return NType(DLDataTypeToString(tensor->dtype->dtype)); else return NType(DLDataTypeToString(dtype)); }; return MapToNestedMsg(ty, fmapleaf); } -NType NTypeFrom(const Expr& expr, DataType dtype) { return NTypeFrom(GetType(expr), dtype); } +NType NTypeFrom(const Expr& expr, DLDataType dtype) { return NTypeFrom(GetType(expr), dtype); } NType NTypeMerge(const NType& a, const NType& b) { auto fcombine = [&](const ffi::String& a_str, const ffi::String& b_str) -> ffi::String { @@ -44,20 +44,20 @@ NType NTypeMerge(const NType& a, const NType& b) { return a_str; } - DataType a = DataType(ffi::StringToDLDataType(a_str)); - DataType b = DataType(ffi::StringToDLDataType(b_str)); - TVM_FFI_ICHECK_EQ(a.code(), b.code()); - TVM_FFI_ICHECK_EQ(a.lanes(), b.lanes()); - return a.bits() > b.bits() ? a_str : b_str; + DLDataType a = ffi::StringToDLDataType(a_str); + DLDataType b = ffi::StringToDLDataType(b_str); + TVM_FFI_ICHECK_EQ(a.code, b.code); + TVM_FFI_ICHECK_EQ(a.lanes, b.lanes); + return a.bits > b.bits ? 
a_str : b_str; }; return CombineNestedMsg(a, b, fcombine); } -ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionFollow(const Call& call, DLDataType out_dtype) { return {IntImm::Int32(MixedPrecisionPolicyKind::kFollow), call}; } -ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { +ffi::Array InferMixedPrecisionNever(const Call& call, DLDataType out_dtype) { return {IntImm::Int32(MixedPrecisionPolicyKind::kNever), call}; } diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index faa33edd4a18..7f9f884a29d0 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -58,10 +58,10 @@ struct NTypeEqual { }; // Construct a NType from an Type -NType NTypeFrom(const Type& ty, DataType dtype = DataType::Void()); +NType NTypeFrom(const Type& ty, DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}); // Construct a NType from an Expr -NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void()); +NType NTypeFrom(const Expr& expr, DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}); // Merge two messages, we keep the higher precision type for each leaf tensor NType NTypeMerge(const NType& a, const NType& b); @@ -70,12 +70,11 @@ NType NTypeMerge(const NType& a, const NType& b); using VarDTypeMap = std::unordered_map; // Call is a call node, out_dtype is the expected output_dtype -using FInferMixedPrecision = - ffi::TypedFunction; +using FInferMixedPrecision = ffi::TypedFunction; -ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionFollow(const Call& call, DLDataType out_dtype); -ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype); +ffi::Array InferMixedPrecisionNever(const Call& call, DLDataType out_dtype); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/lazy_transform_params.cc 
b/src/relax/transform/lazy_transform_params.cc index 7c42928d7d87..b800199610b8 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -65,8 +65,7 @@ class LazyInputMutator : public ExprMutator { param_lookup.insert({func->params[i], i - num_input_params}); } - Var fget_param("fget_param", - FuncType({PrimType(DataType::Int(64)), ObjectType()}, ObjectType())); + Var fget_param("fget_param", FuncType({PrimType::Int(64), ObjectType()}, ObjectType())); ffi::Array new_params(func->params.begin(), func->params.begin() + num_input_params); new_params.push_back(fget_param); @@ -145,7 +144,7 @@ class LazyOutputMutator : public ExprMutator { define_lookup(0, func_body->body); } - Var fset_output("fset_output", FuncType({PrimType(DataType::Int(64)), ObjectType()}, + Var fset_output("fset_output", FuncType({PrimType::Int(64), ObjectType()}, TupleType(ffi::Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 00bd8e859ac3..2c518cfbbeae 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -282,7 +282,7 @@ class LegalizeMutator : public ExprMutator { // This fallback would only be applicable for cases where // both the dtype and the dimensionality are known. While // Relax can express a tensor with unknown dtype and - // dimensionality as `TensorType(DataType::Void(), + // dimensionality as `TensorType(DLDataType{kDLOpaqueHandle, 0, 0}, // kUnknownNDim)`, TIR cannot express unknown dtype or // unknown dimensionality. 
return false; diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 66c2c95b89c2..67cbcc7e8791 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -72,7 +72,8 @@ class Mutator : public ExprMutator { }(); PrimExpr nbytes = [&]() -> PrimExpr { - PrimExpr nbytes = IntImm::Int64(dtype->value.bytes()); + PrimExpr nbytes = IntImm::Int64( + ((((dtype->value).bits * static_cast((dtype->value).lanes)) + 7) / 8)); for (const auto& dim : shape) { nbytes *= dim; } @@ -112,7 +113,7 @@ class Mutator : public ExprMutator { auto offset = PrimValue::Int64(0); Expr storage = relax::Call(mem_alloc_storage_op, {size, runtime_device_index, storage_scope, - DataTypeImm(DataType::UInt(8))}); + DataTypeImm((DLDataType{kDLUInt, 8, 1}))}); storage = builder_->Emit(storage, "storage"); Expr tensor = relax::Call(mem_alloc_tensor_op, {storage, offset, shape_arg, dtype, op->args[2]}); diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 995fe019be04..f8a9e8cde70b 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -289,7 +289,7 @@ Pass RemoveUnusedOutputs() { // into the old tuple, but it's simpler to just let // CanonicalizeBindings and DCE handle it. 
new_results.push_back( - relax::PrimValue(FloatImm(DataType::Float(64), std::nan("")))); + relax::PrimValue(FloatImm(tvm::PrimType::Float(64), std::nan("")))); } } diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index ebe9fa000f77..4f28f9d13132 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -100,7 +100,7 @@ std::optional AnalyzeCallee(Function func) { } for (const auto& tir_var : free_tir_vars) { - Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.dtype())); + Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.ty())); params.push_back(relax_var); } diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index bd36c5cb89c5..7fd0fb7eecaa 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -92,7 +92,7 @@ std::tuple)>> // indices.shape = [outfeatures] // out_table.shape = [*batch, table_size] - auto out_table = matmul(lhs, weights, DataType::Void()); + auto out_table = matmul(lhs, weights, (DLDataType{kDLOpaqueHandle, 0, 0})); // new_output.shape = [*batch, outfeatures] auto new_output = take(out_table, indices, matmul_ty->ndim - 1); @@ -116,7 +116,7 @@ std::tuple)>> auto fused_weight = reshape(reordered_weight, ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]})); // fused_output.shape = [batch1, batch2, table_size * outfeatures] - auto fused_output = matmul(lhs, fused_weight, DataType::Void()); + auto fused_output = matmul(lhs, fused_weight, (DLDataType{kDLOpaqueHandle, 0, 0})); // indexed_output.shape = [batch1, batch2, table_size, outfeatures] auto indexed_output = reshape( fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]})); diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc 
index 4d15c0fd88f5..19e0dfdf8f00 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -129,7 +129,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_a)] = - MakeConstScalar(rhs_ptr->b.dtype(), 1); + MakeConstScalar(rhs_ptr->b.ty(), 1); return true; } } @@ -142,7 +142,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_b)] = - MakeConstScalar(rhs_ptr->a.dtype(), 1); + MakeConstScalar(rhs_ptr->a.ty(), 1); return true; } } @@ -160,7 +160,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_a)] = - MakeConstScalar(rhs_ptr->b.dtype(), 0); + MakeConstScalar(rhs_ptr->b.ty(), 0); return true; } } @@ -173,7 +173,7 @@ class ForMatcher : public TensorizeComparator { if (match) { evaluated_symbols.back().insert(symbol_map.begin(), symbol_map.end()); evaluated_symbols.back()[ffi::GetRef(operand_b)] = - MakeConstScalar(rhs_ptr->a.dtype(), 0); + MakeConstScalar(rhs_ptr->a.ty(), 0); return true; } } @@ -622,7 +622,7 @@ std::pair> SplitFunctions( } } arg_partition->push_back(arg_partition1); - new_params1.push_back(Var("output", DataType::Handle())); + new_params1.push_back(Var("output", PrimType::Handle())); ffi::Map new_buffer_map1; for (const auto& kv : func->buffer_map) { if (partitioner.input1.count(kv.second)) { @@ -635,7 +635,7 @@ std::pair> SplitFunctions( // Step 4. Craft the second function. 
ffi::Array new_params2; std::vector arg_partition2; - new_params2.push_back(Var("input", DataType::Handle())); + new_params2.push_back(Var("input", PrimType::Handle())); for (int i = 0; i < static_cast(func->params.size()); i++) { Var param = func->params[i]; if (partitioner.input2.count(func->buffer_map[param])) { @@ -752,7 +752,7 @@ class SplitMutator : public ExprMutator { TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); tirx::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); - DataType dtype = intermediate_buffer->dtype; + PrimType dtype = intermediate_buffer->dtype; Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, {TensorType(ShapeExpr(intermediate_buffer->shape), dtype)}); Var call_var1 = builder_->Emit(call1); diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 0560582fac59..e09e377e8a70 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -65,11 +65,11 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { ffi::Map buffer_map; for (const auto& info : rewrite_infos_) { - params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); + params.push_back(Var(info.pre_rewrite_buffer->name, PrimType::Handle())); buffer_map.Set(params.back(), info.pre_rewrite_buffer); } for (const auto& info : rewrite_infos_) { - params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle())); + params.push_back(Var(info.post_rewrite_buffer->name, PrimType::Handle())); buffer_map.Set(params.back(), info.post_rewrite_buffer); } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 651b70961090..3d4fcb256d0e 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -106,7 +106,7 @@ class StorageTokenNode : public 
ffi::Object { /*! \brief Number of bytes that this token requires. */ PrimExpr bytes; /*! \brief The dtype of this token. */ - DataType dtype; + DLDataType dtype; /*! \brief The memory scope of the token. */ std::string storage_scope; /*! \brief The VDevice information. */ @@ -135,10 +135,10 @@ class StorageTokenNode : public ffi::Object { */ class StorageToken : public ffi::ObjectRef { public: - explicit StorageToken(ffi::Array shape, DataType dtype, std::string storage_scope, + explicit StorageToken(ffi::Array shape, DLDataType dtype, std::string storage_scope, ffi::Optional vdevice = std::nullopt) { // Compute the tensor size from the shape. - int64_t const_coeff = dtype.bytes() * dtype.lanes(); + int64_t const_coeff = ((((dtype).bits * static_cast((dtype).lanes)) + 7) / 8); PrimExpr size = IntImm::Int64(1); bool size_computed = false; @@ -303,13 +303,16 @@ class TokenAllocatorMixed { } private: - /*! \brief The hash class to enable std::pair as map key class. */ - struct PairHash { - template - std::size_t operator()(const std::pair& p) const { - auto h1 = std::hash{}(p.first); - auto h2 = std::hash{}(p.second); - return h1 ^ h2; + using PoolKey = std::pair; + + /*! \brief The hash class to enable storage scope and raw dtype as map key class. */ + struct PoolKeyHash { + std::size_t operator()(const PoolKey& p) const { + std::size_t h = std::hash{}(p.first); + h ^= static_cast(p.second.code) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= static_cast(p.second.bits) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= static_cast(p.second.lanes) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; } }; @@ -318,9 +321,7 @@ class TokenAllocatorMixed { /*! \brief A constant scale representing the token search range. */ const int match_range_{16}; /*! \brief The pool of available storage tokens for each storage scope and dtype. */ - std::unordered_map, std::multimap, - PairHash> - available_pool_; + std::unordered_map, PoolKeyHash> available_pool_; /*! 
\brief All the storage tokens that have been allocated with actual storage. */ std::vector full_pool_; }; @@ -636,7 +637,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { const auto* shape = ty->shape.as(); TVM_FFI_ICHECK_NOTNULL(shape); TVM_FFI_ICHECK(!ty->IsUnknownDtype()); - TVM_FFI_ICHECK(ty->dtype == call->args[1].as_or_throw()->value); + TVM_FFI_ICHECK(ty->dtype->dtype == call->args[1].as_or_throw()->value); TVM_FFI_ICHECK(!token_map_.count(call)); // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic @@ -653,7 +654,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { } ffi::Optional vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); - StorageToken token(upper_bounded_shape, ty->dtype, storage_scope->value, vdevice); + StorageToken token(upper_bounded_shape, ty->dtype->dtype, storage_scope->value, vdevice); Tokens tokens(token); SetTokens(call, tokens); @@ -938,7 +939,7 @@ class StorageAllocationRewriter : public ExprMutator { if (it_token == token2storage_var_.end()) { ShapeExpr size({token->bytes}); PrimValue virtual_device_index = runtime_device_index; - DataType dtype = token->dtype; + DLDataType dtype = token->dtype; Call alloc_storage(mem_alloc_storage, {std::move(size), virtual_device_index, StringImm(token->storage_scope), DataTypeImm(dtype)}, @@ -951,7 +952,7 @@ class StorageAllocationRewriter : public ExprMutator { // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`. 
PrimValue offset = PrimValue::Int64(0); - DataType dtype = ty->dtype; + DLDataType dtype = ty->dtype->dtype; return Call(mem_alloc_tensor, {storage_var, offset, ty->shape.value(), DataTypeImm(dtype), call->args[2]}, Attrs()); @@ -970,22 +971,23 @@ class StorageAllocationRewriter : public ExprMutator { GetUpperBoundShape(shape->values, ana_.get(), dom_map_); if (!IsStaticShape(shape->values)) { TVM_FFI_ICHECK(!ty->IsUnknownDtype()); - TVM_FFI_ICHECK_EQ(ty->dtype, call->args[1].as_or_throw()->value); + TVM_FFI_ICHECK_EQ(ty->dtype->dtype, call->args[1].as_or_throw()->value); PrimExpr bytes = upper_bounded_shape[0]; for (int i = 1; i < static_cast(upper_bounded_shape.size()); ++i) { bytes *= upper_bounded_shape[i]; } - bytes *= ty->dtype.bytes() * ty->dtype.lanes(); + DLDataType dtype = ty->dtype->dtype; + bytes *= ((((dtype).bits * static_cast((dtype).lanes)) + 7) / 8); Call alloc_storage(mem_alloc_storage, {/*size=*/ShapeExpr({bytes}), /*virtual_device_index=*/call->args[2].as_or_throw(), /*storage_scope=*/call->args[3].as_or_throw(), // - /*dtype=*/DataTypeImm(ty->dtype)}); + /*dtype=*/DataTypeImm(dtype)}); Var storage = builder_->Emit(alloc_storage, "storage"); return Call(mem_alloc_tensor, {storage, // /*offset=*/PrimValue::Int64(0), /*shape=*/ffi::GetRef(shape), // - /*dtype=*/DataTypeImm(ty->dtype), + /*dtype=*/DataTypeImm(dtype), /*vdevice_index=*/call->args[2]}); } } @@ -1040,7 +1042,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.transform.StaticPlanBlockMemory", StaticPlanBlockMemory); } -PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType dtype, +PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DLDataType dtype, VDevice vdevice) { int image_row_align = static_cast( vdevice->target->GetAttr("image_base_address_alignment").value_or(64)); @@ -1056,7 +1058,9 @@ PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d }; auto shape = Shape{pshape}; - size_t size = runtime::GetTextureMemorySize(shape, 
dtype.bytes() * 8, dtype.lanes(), + int lanes = static_cast(dtype.lanes); + TVM_FFI_ICHECK_GE(lanes, 0) << "Can't fetch the bytes of a scalable vector at a compile time."; + size_t size = runtime::GetTextureMemorySize(shape, dtype.bits, lanes, vdevice->memory_scope, image_row_align); return IntImm::Int64(size); } diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index ddd23ce2ea7b..45d2af9b8579 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -116,9 +116,9 @@ int GetMixedPrecisionInfo(const CallNode* call_node) { */ class DTypeDecisionCollector : public ExprVisitor { public: - explicit DTypeDecisionCollector(DataType output_dtype) : output_dtype_(output_dtype) {} + explicit DTypeDecisionCollector(DLDataType output_dtype) : output_dtype_(output_dtype) {} - static VarDTypeMap Collect(Function func, DataType output_dtype) { + static VarDTypeMap Collect(Function func, DLDataType output_dtype) { DTypeDecisionCollector collector(output_dtype); collector.VisitExpr(func); return std::move(collector.only_fp16_map_); @@ -165,7 +165,7 @@ class DTypeDecisionCollector : public ExprVisitor { } // merge the message for all vars in the expr list - void RequireArgsToType(ffi::Array args, DataType to) { + void RequireArgsToType(ffi::Array args, DLDataType to) { std::vector arg_arr; std::vector to_arr; for (const Expr& arg : args) { @@ -262,16 +262,16 @@ class DTypeDecisionCollector : public ExprVisitor { } } - DataType unknown_ = DataType(DataType::TypeCode::kFloat, 0, 1); - DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); - DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); - DataType output_dtype_; + DLDataType unknown_ = DLDataType{kDLFloat, 0, 1}; + DLDataType fp16_ = DLDataType{kDLFloat, 16, 1}; + DLDataType fp32_ = DLDataType{kDLFloat, 32, 1}; + DLDataType output_dtype_; VarDTypeMap only_fp16_map_; }; class ToMixedPrecisionRewriter : public 
ExprMutator { public: - explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype, + explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DLDataType output_dtype, const std::unordered_set& fp16_input_names) : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype), @@ -290,7 +290,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (tensor_ty->vdevice.defined()) { vdev = tensor_ty->vdevice.value(); } - TensorType fp16_ty(tensor_ty->shape.value(), DataType::Float(16), vdev, tensor_ty->span); + TensorType fp16_ty(tensor_ty->shape.value(), PrimType::Float(16), vdev, tensor_ty->span); Var fp16_var(var->vid, fp16_ty, var->span); var_remap_[var->vid] = fp16_var; return fp16_var; @@ -315,13 +315,14 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr; // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not // supported to be rewritten - if (tensor->dtype != fp16_ && tensor->dtype != fp32_) return expr; - return astype(expr, DataType(ffi::StringToDLDataType(to[0].LeafValue()))); + DLDataType tensor_dtype = tensor->dtype->dtype; + if (tensor_dtype != fp16_ && tensor_dtype != fp32_) return expr; + return astype(expr, ffi::StringToDLDataType(to[0].LeafValue())); }; return TransformTupleLeaf(expr, std::array({to}), fvisitleaf); } - ffi::Array RewriteArgs(const ffi::Array& args, DataType to) { + ffi::Array RewriteArgs(const ffi::Array& args, DLDataType to) { ffi::Array new_args; for (const Expr& arg : args) { if (IsNestedTensor(arg)) { @@ -346,7 +347,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { bool AllFP16Castable(const ffi::Array& args) { auto is_fp16 = [](Type ty) { if (auto tensor_ty = ty.as(); - tensor_ty && tensor_ty->dtype == DataType::Float(16)) { + tensor_ty && tensor_ty->dtype == PrimType::Float(16)) { return true; } return false; @@ -359,7 +360,7 @@ class ToMixedPrecisionRewriter : public ExprMutator 
{ return false; } - if (data.DataType() == DataType::Float(16)) { + if (data.DataType() == DLDataType{kDLFloat, 16, 1}) { return true; } @@ -372,17 +373,17 @@ class ToMixedPrecisionRewriter : public ExprMutator { std::vector bytes(size_1d * elem_bytes); data.CopyToBytes(bytes.data(), bytes.size()); - if (data.DataType() == DataType::Float(32)) { + if (data.DataType() == DLDataType{kDLFloat, 32, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Float(64)) { + } else if (data.DataType() == DLDataType{kDLFloat, 64, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(8)) { + } else if (data.DataType() == DLDataType{kDLInt, 8, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(16)) { + } else if (data.DataType() == DLDataType{kDLInt, 16, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(32)) { + } else if (data.DataType() == DLDataType{kDLInt, 32, 1}) { return CheckInFP16Range(bytes, size_1d); - } else if (data.DataType() == DataType::Int(64)) { + } else if (data.DataType() == DLDataType{kDLInt, 64, 1}) { return CheckInFP16Range(bytes, size_1d); } return false; @@ -476,7 +477,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { new_call.CopyOnWrite()->args = RemapArgs(new_call->args); // Then we rewrite the args according to the policy - std::optional opt_new_dtype = std::nullopt; + std::optional opt_new_dtype = std::nullopt; if (policy == kAlways) { opt_new_dtype = fp16_; @@ -589,16 +590,16 @@ class ToMixedPrecisionRewriter : public ExprMutator { const VarDTypeMap* only_fp16_map_; - DataType fp16_ = DataType(DataType::TypeCode::kFloat, 16, 1); - DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1); - DataType output_dtype_; + DLDataType fp16_ = DLDataType{kDLFloat, 16, 1}; + DLDataType fp32_ = DLDataType{kDLFloat, 32, 1}; + DLDataType output_dtype_; ffi::Array params_; std::unordered_set 
fp16_input_names_; const Op& wrap_param_op = Op::Get("relax.wrap_param"); }; -Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, +Expr ToMixedPrecision(const Function& f, DLDataType out_dtype, ffi::Optional> fp16_input_names) { VarDTypeMap only_fp16_map = DTypeDecisionCollector::Collect(f, out_dtype); std::unordered_set fp16_input_names_set; @@ -611,7 +612,7 @@ Expr ToMixedPrecision(const Function& f, const DataType& out_dtype, namespace transform { -Pass ToMixedPrecision(const DataType& out_dtype, +Pass ToMixedPrecision(DLDataType out_dtype, ffi::Optional> fp16_input_names) { auto pass_func = [=](Function f, IRModule m, PassContext pc) { return ToMixedPrecision(f, out_dtype, fp16_input_names).as_or_throw(); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 275c7ca94f8d..0dd6aa6e54a6 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -319,39 +319,39 @@ class FunctionCopier : public SymbolicVarRenewMutator { * \return A Constant. 
*/ template -inline Constant MakeConstantScalar(T value, DataType dtype) { +inline Constant MakeConstantScalar(T value, DLDataType dtype) { runtime::Tensor arr = runtime::Tensor::Empty({}, dtype, {kDLCPU, 0}); - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Float(64)) { + } else if (dtype == DLDataType{kDLFloat, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Bool()) { + } else if (dtype == DLDataType{kDLBool, 1, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(8)) { + } else if (dtype == DLDataType{kDLUInt, 8, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(16)) { + } else if (dtype == DLDataType{kDLUInt, 16, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(32)) { + } else if (dtype == DLDataType{kDLUInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::UInt(64)) { + } else if (dtype == DLDataType{kDLUInt, 64, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(8)) { + } else if (dtype == DLDataType{kDLInt, 8, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(16)) { + } else if (dtype == DLDataType{kDLInt, 16, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { *static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Int(64)) { + } else if (dtype == DLDataType{kDLInt, 64, 1}) { 
*static_cast(arr->data) = static_cast(value); - } else if (dtype == DataType::Float(16)) { + } else if (dtype == DLDataType{kDLFloat, 16, 1}) { // convert to float16 storage is uint16_t *static_cast(arr->data) = __truncXfYf2__(static_cast(value)); - } else if (dtype == DataType::BFloat(16)) { + } else if (dtype == DLDataType{kDLBfloat, 16, 1}) { // convert to bfloat16 storage is uint16_t *static_cast(arr->data) = __truncXfYf2__(static_cast(value)); diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 370947e4b01f..2f5cc6d9dea8 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -179,11 +179,11 @@ tvm::ffi::Map InferSymbolicVarMap( } bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { - DataType dtype; + DLDataType dtype; int ndim; if (const auto* tensor = ty.as()) { - dtype = tensor->dtype; + dtype = tensor->dtype->dtype; ndim = tensor->ndim; } else if (const auto* prim = ty.as()) { dtype = prim->dtype; @@ -192,7 +192,9 @@ bool IsBoolType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dt return false; } - bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype && dtype.is_void()); + // Bool-type matching preserves the old element-code-only behavior; rank is checked separately. + bool correct_dtype = dtype.code == DLDataTypeCode::kDLBool || + (permit_unknown_dtype && dtype == DLDataType{kDLOpaqueHandle, 0, 0}); bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1); return correct_dtype && correct_rank; } diff --git a/src/runtime/extra/contrib/cblas/cblas.cc b/src/runtime/extra/contrib/cblas/cblas.cc index d71eaeb17672..aae0a5acce1c 100644 --- a/src/runtime/extra/contrib/cblas/cblas.cc +++ b/src/runtime/extra/contrib/cblas/cblas.cc @@ -21,10 +21,10 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/extra/contrib/cblas/dnnl_blas.cc b/src/runtime/extra/contrib/cblas/dnnl_blas.cc index 08d72e57b7ad..c267a37aa58e 100644 --- a/src/runtime/extra/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/extra/contrib/cblas/dnnl_blas.cc @@ -21,10 +21,10 @@ * \file Use external cblas library call. */ #include +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/extra/contrib/cblas/gemm_common.h b/src/runtime/extra/contrib/cblas/gemm_common.h index 52f306e86238..65b13aa4c728 100644 --- a/src/runtime/extra/contrib/cblas/gemm_common.h +++ b/src/runtime/extra/contrib/cblas/gemm_common.h @@ -26,8 +26,8 @@ #define TVM_RUNTIME_CONTRIB_CBLAS_GEMM_COMMON_H_ #include +#include #include -#include #include #include @@ -37,7 +37,6 @@ namespace contrib { using ffi::Any; using ffi::PackedArgs; -using runtime::TypeMatch; inline int ColumnStride(const DLTensor* tensor) { // If the tensor itself is transposed then it will have strides @@ -96,8 +95,8 @@ inline void CallGemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -143,9 +142,9 @@ inline void CallU8S8S32Gemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? 
!transb : transb; - TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLUInt, 8)); - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLInt, 8)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK((A->dtype == DLDataType{kDLUInt, 8, 1})); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLInt, 8, 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLInt, 32, 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -207,8 +206,8 @@ inline void CallBatchGemm(ffi::PackedArgs args, ffi::Any* ret, TBatchGemmOp op) transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; - TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK((B->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); + TVM_FFI_ICHECK((C->dtype == DLDataType{kDLFloat, static_cast(bit_depth), 1})); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; diff --git a/src/runtime/extra/contrib/cblas/mkl.cc b/src/runtime/extra/contrib/cblas/mkl.cc index 20f0c539076b..f039df8e676f 100644 --- a/src/runtime/extra/contrib/cblas/mkl.cc +++ b/src/runtime/extra/contrib/cblas/mkl.cc @@ -21,10 +21,10 @@ * \file Use external mkl library call. 
*/ #include +#include #include #include #include -#include extern "C" { #include diff --git a/src/runtime/extra/contrib/coreml/coreml_runtime.mm b/src/runtime/extra/contrib/coreml/coreml_runtime.mm index a72948b250a7..d9823407fb0a 100644 --- a/src/runtime/extra/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/extra/contrib/coreml/coreml_runtime.mm @@ -44,15 +44,15 @@ [shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]]; } - DataType dtype(data_in->dtype); + DLDataType dtype = data_in->dtype; MLMultiArrayDataType dataType; - if (dtype == DataType::Float(64)) { + if (dtype == DLDataType{kDLFloat, 64, 1}) { dataType = MLMultiArrayDataTypeDouble; size *= sizeof(double); - } else if (dtype == DataType::Float(32)) { + } else if (dtype == DLDataType{kDLFloat, 32, 1}) { dataType = MLMultiArrayDataTypeFloat32; size *= sizeof(float); - } else if (dtype == DataType::Int(32)) { + } else if (dtype == DLDataType{kDLInt, 32, 1}) { dataType = MLMultiArrayDataTypeInt32; size *= sizeof(int); } else { @@ -87,15 +87,15 @@ shape.push_back(n); } - DataType dtype; + DLDataType dtype = DLDataType{kDLOpaqueHandle, 0, 0}; if (data_desc.dataType == MLMultiArrayDataTypeDouble) { - dtype = DataType::Float(64); + dtype = DLDataType{kDLFloat, 64, 1}; size *= sizeof(double); } else if (data_desc.dataType == MLMultiArrayDataTypeFloat32) { - dtype = DataType::Float(32); + dtype = DLDataType{kDLFloat, 32, 1}; size *= sizeof(float); } else if (data_desc.dataType == MLMultiArrayDataTypeInt32) { - dtype = DataType::Int(32); + dtype = DLDataType{kDLInt, 32, 1}; size *= sizeof(int); } else { LOG(FATAL) << "unexpected data type " << data_desc.dataType; diff --git a/src/runtime/extra/contrib/cublas/cublas.cc b/src/runtime/extra/contrib/cublas/cublas.cc index 4ef1b702c16c..f114cfa6e939 100644 --- a/src/runtime/extra/contrib/cublas/cublas.cc +++ b/src/runtime/extra/contrib/cublas/cublas.cc @@ -21,11 +21,11 @@ * \file Use external cblas library call. 
*/ #include +#include #include #include #include #include -#include #include "../../../../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" @@ -170,9 +170,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, ab_type = CUDA_R_16BF; } else if (TypeMatch(A->dtype, kDLInt, 8)) { ab_type = CUDA_R_8I; - } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) { + } else if (TypeMatch(A->dtype, kDLFloat8_e4m3fn, 8)) { #if CUDART_VERSION >= 11080 - TVM_FFI_ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat8_e4m3fn, 8)); ab_type = CUDA_R_8F_E4M3; #else TVM_FFI_THROW(InternalError) << "Float8 (E4M3) is only supported in CUDA 11.8 and above."; diff --git a/src/runtime/extra/contrib/cudnn/conv_backward.cc b/src/runtime/extra/contrib/cudnn/conv_backward.cc index df3d7c8e6ff7..97832248fe53 100644 --- a/src/runtime/extra/contrib/cudnn/conv_backward.cc +++ b/src/runtime/extra/contrib/cudnn/conv_backward.cc @@ -21,9 +21,9 @@ * \file cuDNN kernel calls for backward algorithms. */ #include +#include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/cudnn/conv_forward.cc b/src/runtime/extra/contrib/cudnn/conv_forward.cc index 3a573297f29e..b7257d35f2b5 100644 --- a/src/runtime/extra/contrib/cudnn/conv_forward.cc +++ b/src/runtime/extra/contrib/cudnn/conv_forward.cc @@ -21,9 +21,9 @@ * \file cuDNN kernel calls for the forward algorithm. 
*/ #include +#include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc index 5c34d4a2b0a6..3edb20dbacbc 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc @@ -23,10 +23,10 @@ #include "cudnn_utils.h" +#include #include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh index 35c4a5767236..85653222169b 100644 --- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh +++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm.cuh @@ -49,17 +49,17 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor float alpha = 1.0f; float beta = 0.0f; - if (DataType(x->dtype) == DataType::Float(16)) { - TVM_FFI_ICHECK(DataType(weight->dtype) == DataType::Float(16)); - TVM_FFI_ICHECK(DataType(out->dtype) == DataType::Float(16)); + if (x->dtype == DLDataType{kDLFloat, 16, 1}) { + TVM_FFI_ICHECK((weight->dtype == DLDataType{kDLFloat, 16, 1})); + TVM_FFI_ICHECK((out->dtype == DLDataType{kDLFloat, 16, 1})); using Dtype = cutlass::half_t; CutlassGroupGemm::run( static_cast(x->data), static_cast(weight->data), static_cast(indptr->data), static_cast(workspace->data), workspace->shape[0], n, k, num_groups, alpha, beta, static_cast(out->data), stream); - } else if (DataType(x->dtype) == DataType::BFloat(16)) { - TVM_FFI_ICHECK(DataType(weight->dtype) == DataType::BFloat(16)); - TVM_FFI_ICHECK(DataType(out->dtype) == DataType::BFloat(16)); + } else if (x->dtype == DLDataType{kDLBfloat, 16, 1}) { + TVM_FFI_ICHECK((weight->dtype == DLDataType{kDLBfloat, 16, 1})); + TVM_FFI_ICHECK((out->dtype == DLDataType{kDLBfloat, 16, 1})); using Dtype = cutlass::bfloat16_t; CutlassGroupGemm::run( static_cast(x->data), static_cast(weight->data), diff --git 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index db88ec0faaed..1af60af4da3a 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -66,14 +66,15 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale TVM_FFI_ICHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]); TVM_FFI_ICHECK_EQ(scales_b->shape[1] * block_size_1, k); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); + int64_t workspace_nbytes = + workspace->shape[0] * ((workspace->dtype.bits * workspace->dtype.lanes + 7) / 8); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -81,10 +82,9 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, 1, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { 
CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -92,11 +92,10 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, 1, stream); } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + LOG(FATAL) << "Unsupported output dtype: " << out->dtype; } } @@ -131,14 +130,15 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales TVM_FFI_ICHECK_EQ(scales_b->shape[1] * block_size_0, n); TVM_FFI_ICHECK_EQ(scales_b->shape[2] * block_size_1, k); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); + int64_t workspace_nbytes = + workspace->shape[0] * ((workspace->dtype.bits * workspace->dtype.lanes + 7) / 8); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -146,10 +146,9 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * 
DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, batch_size, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { CutlassFP8GroupwiseGemm::run(static_cast(a->data), @@ -157,11 +156,10 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(out->data), - static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + static_cast(workspace->data), workspace_nbytes, m, n, k, batch_size, stream); } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + LOG(FATAL) << "Unsupported output dtype: " << out->dtype; } } diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index ea70eee38650..6bd9f45ab25e 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -57,15 +57,14 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales TVM_FFI_ICHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]); TVM_FFI_ICHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]); - using tvm::runtime::DataType; - TVM_FFI_ICHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - TVM_FFI_ICHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - TVM_FFI_ICHECK_EQ(DataType(indptr->dtype), DataType::Int(64)); - TVM_FFI_ICHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + TVM_FFI_ICHECK_EQ(a->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + TVM_FFI_ICHECK_EQ(b->dtype, DLDataType{kDLFloat8_e4m3fn, 8, 1}); + 
TVM_FFI_ICHECK_EQ(scales_a->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(scales_b->dtype, DLDataType{kDLFloat, 32, 1}); + TVM_FFI_ICHECK_EQ(indptr->dtype, DLDataType{kDLInt, 64, 1}); + TVM_FFI_ICHECK_EQ(workspace->dtype, DLDataType{kDLUInt, 8, 1}); - if (DataType(out->dtype) == DataType::Float(16)) { + if (out->dtype == DLDataType{kDLFloat, 16, 1}) { using Dtype = cutlass::half_t; cutlass_fp8_groupwise_scaled_group_gemm_sm100( @@ -73,7 +72,7 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales static_cast(scales_a->data), static_cast(scales_b->data), static_cast(indptr->data), static_cast(workspace->data), workspace->shape[0], n, k, num_groups, static_cast(out->data), stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { + } else if (out->dtype == DLDataType{kDLBfloat, 16, 1}) { using Dtype = cutlass::bfloat16_t; cutlass_fp8_groupwise_scaled_group_gemm_sm100( diff --git a/src/runtime/extra/contrib/dnnl/dnnl_utils.cc b/src/runtime/extra/contrib/dnnl/dnnl_utils.cc index 23992209f2ad..e41d378b3d30 100644 --- a/src/runtime/extra/contrib/dnnl/dnnl_utils.cc +++ b/src/runtime/extra/contrib/dnnl/dnnl_utils.cc @@ -32,21 +32,21 @@ namespace contrib { dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype) { using dt = dnnl::memory::data_type; dt dnnl_type = dt::undef; - if (dltype.code == DataType::TypeCode::kFloat) { + if (dltype.code == DLDataTypeCode::kDLFloat) { if (dltype.bits == 16) { dnnl_type = dt::f16; } else if (dltype.bits == 32) { dnnl_type = dt::f32; } - } else if (dltype.code == DataType::TypeCode::kBFloat && dltype.bits == 16) { + } else if (dltype.code == DLDataTypeCode::kDLBfloat && dltype.bits == 16) { dnnl_type = dt::bf16; - } else if (dltype.code == DataType::TypeCode::kInt) { + } else if (dltype.code == DLDataTypeCode::kDLInt) { if (dltype.bits == 8) { dnnl_type = dt::s8; } else if (dltype.bits == 32) { dnnl_type = dt::s32; } - } else if (dltype.code == DataType::TypeCode::kUInt && dltype.bits 
== 8) { + } else if (dltype.code == DLDataTypeCode::kDLUInt && dltype.bits == 8) { dnnl_type = dt::u8; } if (dnnl_type == dt::undef) { diff --git a/src/runtime/extra/contrib/dnnl/dnnl_utils.h b/src/runtime/extra/contrib/dnnl/dnnl_utils.h index a598b6704450..6f36ed4d8fbe 100644 --- a/src/runtime/extra/contrib/dnnl/dnnl_utils.h +++ b/src/runtime/extra/contrib/dnnl/dnnl_utils.h @@ -34,7 +34,7 @@ // -Wzero-as-null-pointer-constant and -Wdocumentation-unknown-command #include -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace runtime { diff --git a/src/runtime/extra/contrib/hipblas/hipblas.cc b/src/runtime/extra/contrib/hipblas/hipblas.cc index 5276b4f7956d..eae6f7241cc7 100644 --- a/src/runtime/extra/contrib/hipblas/hipblas.cc +++ b/src/runtime/extra/contrib/hipblas/hipblas.cc @@ -21,10 +21,10 @@ * \file Use external hipblas library call. */ #include +#include #include #include #include -#include #include "../../../../../3rdparty/compiler-rt/builtin_fp16.h" #include "../cblas/gemm_common.h" diff --git a/src/runtime/extra/contrib/json/json_node.h b/src/runtime/extra/contrib/json/json_node.h index c165f6b05cf3..40c96d826914 100644 --- a/src/runtime/extra/contrib/json/json_node.h +++ b/src/runtime/extra/contrib/json/json_node.h @@ -29,9 +29,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc index cb6e3520c8c1..1483563b6200 100644 --- a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc @@ -57,7 +57,7 @@ class NVSHMEMAllocator final : public PooledAllocator { return allocator; } - Tensor Empty(ffi::Shape shape, DataType dtype, Device device) { + Tensor Empty(ffi::Shape shape, DLDataType dtype, Device device) { class NVSHMEMAlloc { public: explicit NVSHMEMAlloc(Buffer buffer) : buffer_(buffer) {} @@ -87,7 +87,7 @@ 
class NVSHMEMAllocator final : public PooledAllocator { void DeviceFreeDataSpace(Device dev, void* ptr) final { nvshmem_free(ptr); } }; -Tensor NVSHMEMEmpty(ffi::Shape shape, DataType dtype, ffi::Optional device) { +Tensor NVSHMEMEmpty(ffi::Shape shape, DLDataType dtype, ffi::Optional device) { return NVSHMEMAllocator::Global()->Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } diff --git a/src/runtime/extra/contrib/random/random.cc b/src/runtime/extra/contrib/random/random.cc index a3d0cd8b85a8..81db658cb86e 100644 --- a/src/runtime/extra/contrib/random/random.cc +++ b/src/runtime/extra/contrib/random/random.cc @@ -21,10 +21,10 @@ * \file External random functions for tensor. */ #include +#include #include #include #include -#include #include #include diff --git a/src/runtime/extra/contrib/sort/sort.cc b/src/runtime/extra/contrib/sort/sort.cc index 51a94111b6e6..6e3a99f93522 100644 --- a/src/runtime/extra/contrib/sort/sort.cc +++ b/src/runtime/extra/contrib/sort/sort.cc @@ -23,10 +23,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -36,8 +36,6 @@ namespace tvm { namespace contrib { -using namespace runtime; - template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { if constexpr (stable_comparison) { diff --git a/src/runtime/extra/contrib/vllm/cache_alloc.cc b/src/runtime/extra/contrib/vllm/cache_alloc.cc index 266138406cb9..42601d7a5e69 100644 --- a/src/runtime/extra/contrib/vllm/cache_alloc.cc +++ b/src/runtime/extra/contrib/vllm/cache_alloc.cc @@ -39,9 +39,9 @@ ffi::Array AllocateKVCache(int head_size, int num_layers, int num_heads, for (int i = 0; i < num_layers; ++i) { Tensor key_blocks = Tensor::Empty({num_blocks, num_heads, head_size / vec_size, block_size, vec_size}, - runtime::DataType::Float(16), dev); + DLDataType{kDLFloat, 16, 1}, dev); Tensor value_blocks = Tensor::Empty({num_blocks, num_heads, head_size, block_size}, - runtime::DataType::Float(16), dev); + DLDataType{kDLFloat, 16, 1}, 
dev); cache.push_back(key_blocks); cache.push_back(value_blocks); } diff --git a/src/runtime/extra/contrib/vllm/cache_kernels.cu b/src/runtime/extra/contrib/vllm/cache_kernels.cu index 5af93a1fd904..6a09497a8d12 100644 --- a/src/runtime/extra/contrib/vllm/cache_kernels.cu +++ b/src/runtime/extra/contrib/vllm/cache_kernels.cu @@ -206,16 +206,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { DLDevice dev = key_cache->device; Tensor key_cache_ptrs_gpu = - Tensor::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + Tensor::Empty({static_cast(num_layers)}, DLDataType{kDLInt, 64, 1}, dev); Tensor value_cache_ptrs_gpu = - Tensor::Empty({static_cast(num_layers)}, runtime::DataType::Int(64), dev); + Tensor::Empty({static_cast(num_layers)}, DLDataType{kDLInt, 64, 1}, dev); key_cache_ptrs_gpu.CopyFromBytes(key_cache_ptrs.data(), sizeof(int64_t) * key_cache_ptrs.size()); value_cache_ptrs_gpu.CopyFromBytes(value_cache_ptrs.data(), sizeof(int64_t) * value_cache_ptrs.size()); Tensor block_mapping_gpu = - Tensor::Empty(block_mapping.Shape(), runtime::DataType::Int(64), dev); + Tensor::Empty(block_mapping.Shape(), DLDataType{kDLInt, 64, 1}, dev); block_mapping_gpu.CopyFromBytes(block_mapping->data, sizeof(int64_t) * block_mapping->shape[0]); diff --git a/src/runtime/extra/disco/builtin.cc b/src/runtime/extra/disco/builtin.cc index da9f472b3e76..d9d5fc132768 100644 --- a/src/runtime/extra/disco/builtin.cc +++ b/src/runtime/extra/disco/builtin.cc @@ -71,7 +71,7 @@ ffi::Module LoadVMModule(std::string path, ffi::Optional device) { return mod; } -Tensor DiscoEmptyTensor(ffi::Shape shape, DataType dtype, ffi::Optional device) { +Tensor DiscoEmptyTensor(ffi::Shape shape, DLDataType dtype, ffi::Optional device) { return Tensor::Empty(shape, dtype, UseDefaultDeviceIfNone(device)); } @@ -131,7 +131,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("runtime.disco.load_vm_module", LoadVMModule) .def("runtime.disco.empty", - [](ffi::Shape shape, DataType dtype, ffi::Optional device, 
bool worker0_only, + [](ffi::Shape shape, DLDataType dtype, ffi::Optional device, bool worker0_only, bool in_group) -> ffi::Optional { int worker_id = WorkerId(); int group_size = diff --git a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc index 426557b7b7ad..a8a8030f0169 100644 --- a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc @@ -97,10 +97,12 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { auto [data_ptr, data_comm_ptrs] = AllocIPCMemory(dev, size, alignment, type_hint, /*reset_memory_to_zero=*/false); int barrier_ptr_size = sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; - auto [barrier_in_ptr, barrier_in_comm_ptrs] = AllocIPCMemory( - dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); - auto [barrier_out_ptr, barrier_out_comm_ptrs] = AllocIPCMemory( - dev, barrier_ptr_size, alignment, DataType::UInt(32), /*reset_memory_to_zero=*/true); + auto [barrier_in_ptr, barrier_in_comm_ptrs] = + AllocIPCMemory(dev, barrier_ptr_size, alignment, DLDataType{kDLUInt, 32, 1}, + /*reset_memory_to_zero=*/true); + auto [barrier_out_ptr, barrier_out_comm_ptrs] = + AllocIPCMemory(dev, barrier_ptr_size, alignment, DLDataType{kDLUInt, 32, 1}, + /*reset_memory_to_zero=*/true); // Create the CUDAIPCMemory object. ffi::ObjectPtr ipc_memory = ffi::make_object(); diff --git a/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc index ffe00d5feef9..3eaca5ba98d4 100644 --- a/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/extra/disco/cuda_ipc/custom_allreduce.cc @@ -81,7 +81,7 @@ void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { // Dispatch to nccl AllReduce if the customized all-reduce cannot apply. 
deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllReduce(send->data, recv->data, num_elements, - /*datatype=*/nccl::AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/nccl::AsNCCLDataType(send->dtype), /*op=*/ncclSum, ctx->global_comm, stream)); return; } diff --git a/src/runtime/extra/disco/loader.cc b/src/runtime/extra/disco/loader.cc index 86caac6573ed..f714112aecf3 100644 --- a/src/runtime/extra/disco/loader.cc +++ b/src/runtime/extra/disco/loader.cc @@ -17,10 +17,10 @@ * under the License. */ #include +#include #include #include #include -#include #include #include @@ -45,7 +45,7 @@ using ParamRecord = TensorCacheMetadata::FileRecord::ParamRecord; struct ShardInfo { struct TensorInfo { ffi::Shape shape; - DataType dtype; + DLDataType dtype; }; struct ShardFunc { std::string name; @@ -67,8 +67,7 @@ ShardInfo::TensorInfo LoadTensorInfoFromJSON(const json::Array& json_tensor_info shape.push_back(shape_json[i].cast()); } std::string dtype = json_tensor_info[1].cast(); - return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), - DataType(ffi::StringToDLDataType(dtype))}; + return ShardInfo::TensorInfo{ffi::Shape(std::move(shape)), ffi::StringToDLDataType(dtype)}; } ShardInfo::ShardFunc LoadShardFuncFromJSON(const json::Array& json_shard_func) { @@ -301,7 +300,7 @@ Tensor ShardLoaderObj::Load(int weight_index) const { bool needs_sharding = !param_info.shard_info.funcs.empty(); if (needs_sharding) { ffi::Shape shape = param_info.shard_info.funcs.back().output_info.shape; - DataType dtype = param_info.shard_info.funcs.back().output_info.dtype; + DLDataType dtype = param_info.shard_info.funcs.back().output_info.dtype; TVM_FFI_CHECK(shape.size() >= 1 && shape[0] == num_shards, ValueError) << "The first dimension of the " << "output shape must be equal to the " diff --git a/src/runtime/extra/disco/nccl/nccl.cc b/src/runtime/extra/disco/nccl/nccl.cc index 887f440b1b4f..cd00a1ac3d6b 100644 --- a/src/runtime/extra/disco/nccl/nccl.cc +++ 
b/src/runtime/extra/disco/nccl/nccl.cc @@ -122,8 +122,8 @@ void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) ffi::Shape shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); - DataType dtype = DataType(send->dtype); - if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) { + DLDataType dtype = send->dtype; + if (dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1} || dtype == DLDataType{kDLFloat8_e5m2, 8, 1}) { TVM_FFI_THROW(InternalError) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; } @@ -139,7 +139,7 @@ void AllGather(Tensor send, bool in_group, Tensor recv) { int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclAllGather(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/AsNCCLDataType(send->dtype), in_group ? ctx->group_comm : ctx->global_comm, stream)); } @@ -162,7 +162,7 @@ void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv deviceStream_t stream = ctx->GetDefaultStream(); NCCL_CALL(ncclBroadcast(send_data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(recv->dtype)), + /*datatype=*/AsNCCLDataType(recv->dtype), /*root=*/0, in_group ? 
ctx->group_comm : ctx->global_comm, stream)); } @@ -185,9 +185,9 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) "of elements in the buffer to be " "divisible by the number of workers, but got numel = " << numel << " and " << num_receiver << " workers."; - DataType dtype(buffer->dtype); + DLDataType dtype = buffer->dtype; int64_t numel_per_shard = numel / num_receiver; - int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); + int64_t bytes_per_shard = numel_per_shard * ((dtype.bits * dtype.lanes + 7) / 8); TVM_FFI_CHECK_EQ(numel_per_shard, recv.Shape().Product(), ValueError) << "The number of elements in buffer `recv` must be the same as each shard " "of " @@ -209,7 +209,7 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) NCCL_CALL(ncclGroupStart()); } int64_t numel = recv.Shape().Product(); - DataType dtype(recv->dtype); + DLDataType dtype = recv->dtype; NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0, in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); @@ -234,9 +234,9 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { "of elements in the buffer to be " "divisible by the number of workers, but got numel = " << numel << " and " << num_receiver << " workers."; - DataType dtype(buffer->dtype); + DLDataType dtype = buffer->dtype; int64_t numel_per_shard = numel / num_receiver; - int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); + int64_t bytes_per_shard = numel_per_shard * ((dtype.bits * dtype.lanes + 7) / 8); TVM_FFI_CHECK_EQ(numel_per_shard, send.Shape().Product(), ValueError) << "The number of elements in buffer `send` must be the same as each shard " "of " @@ -258,7 +258,7 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { NCCL_CALL(ncclGroupStart()); } int64_t numel = send.Shape().Product(); - DataType dtype(send->dtype); + DLDataType dtype = send->dtype; NCCL_CALL(ncclSend(send->data, numel, 
AsNCCLDataType(dtype), 0, in_group ? ctx->group_comm : ctx->global_comm, stream)); NCCL_CALL(ncclGroupEnd()); diff --git a/src/runtime/extra/disco/nccl/nccl_context.h b/src/runtime/extra/disco/nccl/nccl_context.h index 7a99be0897c0..d529ab441d11 100644 --- a/src/runtime/extra/disco/nccl/nccl_context.h +++ b/src/runtime/extra/disco/nccl/nccl_context.h @@ -86,39 +86,39 @@ inline void StreamDestroy(deviceStream_t stream) { ROCM_CALL(hipStreamDestroy(st #endif -/*! \brief Convert DataType to ncclDataType. */ -inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { - if (dtype == DataType::Int(8)) { +/*! \brief Convert DLPack dtype to ncclDataType. */ +inline ncclDataType_t AsNCCLDataType(DLDataType dtype) { + if (dtype == DLDataType{kDLInt, 8, 1}) { return ncclInt8; } - if (dtype == DataType::UInt(8) || dtype == DataType::Float8E4M3FN() || - dtype == DataType::Float8E5M2()) { + if (dtype == DLDataType{kDLUInt, 8, 1} || dtype == DLDataType{kDLFloat8_e4m3fn, 8, 1} || + dtype == DLDataType{kDLFloat8_e5m2, 8, 1}) { // For float8 data type, pretend to be uint8 in nccl. // And will throw error when allreduce, as it makes no sense in this case. 
return ncclUint8; } - if (dtype == DataType::Int(32)) { + if (dtype == DLDataType{kDLInt, 32, 1}) { return ncclInt32; } - if (dtype == DataType::UInt(32)) { + if (dtype == DLDataType{kDLUInt, 32, 1}) { return ncclUint32; } - if (dtype == DataType::Int(64)) { + if (dtype == DLDataType{kDLInt, 64, 1}) { return ncclInt64; } - if (dtype == DataType::UInt(64)) { + if (dtype == DLDataType{kDLUInt, 64, 1}) { return ncclUint64; } - if (dtype == DataType::Float(16)) { + if (dtype == DLDataType{kDLFloat, 16, 1}) { return ncclFloat16; } - if (dtype == DataType::Float(32)) { + if (dtype == DLDataType{kDLFloat, 32, 1}) { return ncclFloat32; } - if (dtype == DataType::Float(64)) { + if (dtype == DLDataType{kDLFloat, 64, 1}) { return ncclFloat64; } - if (dtype == DataType::BFloat(16)) { + if (dtype == DLDataType{kDLBfloat, 16, 1}) { return ncclBfloat16; } TVM_FFI_THROW(ValueError) << "Unsupported data type " << dtype; diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 887d576537f2..ed12d0b4885a 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -33,7 +33,7 @@ #include "../support/base64.h" #include "../support/bytes_io.h" -#include "tvm/runtime/data_type.h" +#include "tvm/ffi/dtype.h" namespace tvm { namespace runtime { @@ -52,11 +52,11 @@ inline void VerifyDataType(DLDataType dtype) { return; else if (dtype.bits == 4 && dtype.code == kDLInt) return; - else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e2m3fn) + else if (dtype.bits == 6 && dtype.code == kDLFloat6_e2m3fn) return; - else if (dtype.bits == 6 && dtype.code == DataType::kFloat6_e3m2fn) + else if (dtype.bits == 6 && dtype.code == kDLFloat6_e3m2fn) return; - else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn) + else if (dtype.bits == 4 && dtype.code == kDLFloat4_e2m1fn) return; else TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 067fa8d10dc1..6aececc755ea 100644 --- 
a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -321,7 +321,7 @@ class PagedDecodeFunc : public AttnBackendFunc { Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, - RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + RoPEMode rope_mode, DLDataType q_dtype, DLDataType kv_dtype, TVMStreamHandle copy_stream) { // Do nothing. Subclasses can override to customize behavior. } @@ -377,7 +377,7 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { Tensor page_locked_int_workspace_buffer, HostMemoryVector* page_indptr, int64_t batch_size, int64_t page_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t v_head_dim, - RoPEMode rope_mode, DataType q_dtype, DataType kv_dtype, + RoPEMode rope_mode, DLDataType q_dtype, DLDataType kv_dtype, TVMStreamHandle copy_stream) final { // Todo(tvm-team): enable cuda graph ffi::Shape plan_info_vec = diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 7a2c93414c0f..4f9cd648e9d7 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -359,7 +359,7 @@ class HostMemoryVector { explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) : reserved_size_(reserved_size) { - TVM_FFI_ICHECK(DataType(dtype) == DataType::Int(32)); + TVM_FFI_ICHECK((dtype == DLDataType{kDLInt, 32, 1})); data_ = Tensor::Empty({reserved_size}, dtype, device); } @@ -368,7 +368,7 @@ class HostMemoryVector { if (current_size_ == reserved_size_) { reserved_size_ *= 2; Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); - std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); + std::memcpy(new_data->data, data_->data, current_size_ * (((data_->dtype).bits + 7) / 8)); data_ = new_data; } static_cast(data_->data)[current_size_++] = value; 
@@ -382,7 +382,7 @@ class HostMemoryVector { reserved_size_ *= 2; } Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); - std::memcpy(new_data->data, data_->data, current_size_ * DataType(data_->dtype).bytes()); + std::memcpy(new_data->data, data_->data, current_size_ * (((data_->dtype).bits + 7) / 8)); data_ = new_data; } std::memcpy(static_cast(data_->data) + current_size_, values.data(), @@ -466,7 +466,7 @@ class PagedKVCacheAuxDataManager { device_(device), preferred_host_device_(preferred_host_device), copy_stream_(copy_stream) { - TVM_FFI_ICHECK(DataType(dtype_aux) == DataType::Int(32)); + TVM_FFI_ICHECK((dtype_aux == DLDataType{kDLInt, 32, 1})); } virtual ~PagedKVCacheAuxDataManager() = default; diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 8fc18c5c0722..30fbf77b9c7f 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -22,11 +22,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -243,14 +243,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { ffi::AnyView arg = args[0]; int ndim = args[1].cast(); - DataType dtype; + DLDataType dtype; ffi::Optional err_ctx; if (args.size() == 3) { - dtype = DataType::Void(); + dtype = DLDataType{kDLOpaqueHandle, 0, 0}; err_ctx = args[2].cast>(); } else { - dtype = args[2].cast(); + dtype = args[2].cast(); err_ctx = args[3].cast>(); } @@ -264,10 +264,10 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { << err_ctx.value_or("") << " expect Tensor with ndim " << ndim << " but get " << ptr->ndim; } - if (dtype != DataType::Void()) { - TVM_FFI_CHECK(DataType(ptr->dtype) == dtype, ValueError) + if (dtype != DLDataType{kDLOpaqueHandle, 0, 0}) { + TVM_FFI_CHECK(ptr->dtype == dtype, ValueError) << err_ctx.value_or("") << " expect Tensor with dtype " << dtype << " but get " - << DataType(ptr->dtype); + << ptr->dtype; } } @@ -301,23 +301,24 
@@ TVM_FFI_STATIC_INIT_BLOCK() { /*! * \brief Builtin function to check if arg is PrimValue(dtype) * \param arg The input argument. - * \param dtype Expected dtype of the PrimValue. Can be DataType::Void() for unknown dtype. + * \param dtype Expected dtype of the PrimValue. Can be DLDataType{kDLOpaqueHandle, 0, 0} for + * unknown dtype. * \param err_ctx Additional context if error occurs. */ -void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional err_ctx) { +void CheckPrimValueInfo(ffi::AnyView arg, DLDataType dtype, ffi::Optional err_ctx) { if (auto opt_obj = arg.as()) { TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", expected dtype " << dtype << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); - } else if (dtype.is_bool()) { + } else if (((dtype).code == kDLBool)) { arg.cast(); - } else if (dtype.is_int()) { + } else if (((dtype).code == kDLInt)) { arg.cast(); - } else if (dtype.is_uint()) { + } else if (((dtype).code == kDLUInt)) { arg.cast(); - } else if (dtype.is_float()) { + } else if (((dtype).code == kDLFloat)) { arg.cast(); - } else if (dtype.is_handle()) { + } else if (dtype.code == kDLOpaqueHandle && !(dtype.bits == 0 && dtype.lanes == 0)) { arg.cast(); } else { TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", unsupported dtype " << dtype; @@ -398,7 +399,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { Storage sobj = args[0].cast(); int64_t offset = args[1].cast(); ffi::Shape shape = args[2].cast(); - DataType dtype = args[3].cast(); + DLDataType dtype = args[3].cast(); if (args.size() == 5) { ffi::String scope = args[4].cast(); *rv = sobj->AllocTensorScoped(offset, shape, dtype, scope); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 33ff1503f823..9e3a5f932309 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -101,8 +101,7 @@ std::string VMExecutable::Stats() const { oss << opt_int.value(); oss << ", "; } else if (auto opt_dtype = it.as()) { - DataType 
dtype(opt_dtype.value()); - oss << dtype; + oss << opt_dtype.value(); oss << ", "; } else { TVM_FFI_THROW(InternalError) << "Unsupported constant pool type " << it.GetTypeKey(); diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index 51b271441a27..2516e0d8a1af 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -362,7 +362,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // NOTE this is a built-in highly related to LM so we put it here. int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double uniform_sample) { TVM_FFI_ICHECK(logits.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})); if (logits->device.device_type != kDLCPU) { logits = logits.CopyTo(DLDevice{kDLCPU, 0}); @@ -428,7 +428,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) { TVM_FFI_ICHECK(prob.IsContiguous()); - TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK((prob.DataType() == DLDataType{kDLFloat, 32, 1})); if (prob->device.device_type != kDLCPU) { prob = prob.CopyTo(DLDevice{kDLCPU, 0}); @@ -543,7 +543,8 @@ Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) { int64_t vocab_size = prob->shape[prob->ndim - 1]; const float* pprob = static_cast(prob->data); const float* psample = static_cast(uniform_sample->data); - Tensor new_array = Tensor::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device); + Tensor new_array = + Tensor::Empty({batch_size, 1}, DLDataType{kDLInt, 64, 1}, uniform_sample->device); int64_t* parray = static_cast(new_array->data); for (int64_t i = 0; i < batch_size; ++i) { float cum_sum_prob = 0.0f; @@ -569,8 +570,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { void ApplyRepetitionPenalty(Tensor logits, Tensor token_ids, double penalty) { TVM_FFI_ICHECK(logits.IsContiguous()); TVM_FFI_ICHECK(token_ids.IsContiguous()); - 
TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; + TVM_FFI_ICHECK((token_ids.DataType() == DLDataType{kDLInt, 32, 1})) << "token ids must be int32!"; TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; float* logits_raw_data = static_cast(logits->data); @@ -606,9 +608,11 @@ void ApplyPresenceAndFrequencyPenalty(Tensor logits, Tensor token_ids, Tensor to TVM_FFI_ICHECK(logits.IsContiguous()); TVM_FFI_ICHECK(token_ids.IsContiguous()); TVM_FFI_ICHECK(token_freqs.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; - TVM_FFI_ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; + TVM_FFI_ICHECK((token_ids.DataType() == DLDataType{kDLInt, 32, 1})) << "token ids must be int32!"; + TVM_FFI_ICHECK((token_freqs.DataType() == DLDataType{kDLInt, 32, 1})) + << "token freqs must be int32!"; TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; TVM_FFI_ICHECK(token_freqs->device.device_type == kDLCPU) << "token_ids device must be CPU!"; @@ -633,7 +637,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // This is an inplace operation. 
void ApplySoftmaxWithTemperature(Tensor logits, double temperature) { TVM_FFI_ICHECK(logits.IsContiguous()); - TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + TVM_FFI_ICHECK((logits.DataType() == DLDataType{kDLFloat, 32, 1})) + << "Logits data type is not float32!"; TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; int vocab_size = logits->shape[logits->ndim - 1]; float* logits_raw_data = static_cast(logits->data); diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index e5c4576e01c1..cd7920d6eef0 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -116,9 +116,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const ffi::Optional rope_ext_factors_; /*! \brief The KV cache dtype. */ - const DataType kv_dtype_; + const DLDataType kv_dtype_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ - const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); + const DLDataType dtype_aux_ = DLDataType{kDLInt, 32, 1}; /********************* Page Structures *********************/ @@ -326,7 +326,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { rotary_scale_(rotary_scale), rotary_theta_(rotary_theta), rope_ext_factors_(std::move(rope_ext_factors)), - kv_dtype_(DataType(dtype)), + kv_dtype_(dtype), reserved_num_seqs_(reserved_num_seqs), f_transpose_append_mha_(std::move(f_transpose_append_mha)), f_transpose_append_mla_(std::move(f_transpose_append_mla)), @@ -372,7 +372,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { pages_.push_back(nvshmem_pages_.CreateView( {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, nvshmem_pages_->dtype, i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * qk_head_dim_ * - nvshmem_pages_.DataType().bytes())); + /* bytes per element; parenthesized so the element count is not divided */ ((nvshmem_pages_.DataType().bits + 7) / 8))); } const auto f_transfer_kv_ptr = 
tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer"); @@ -450,9 +450,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) { if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + Tensor::Empty({kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device)); temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( - {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); + {kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, GetPreferredHostDevice(device))); } qo_indptr_on_depths_view_.push_back(Tensor()); page_indptr_on_depths_view_.push_back(Tensor()); @@ -470,11 +470,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Additional workspace for the "prefill with ragged kv" kernel. if (NeedKernelBeginForward()) { temp_int_attn_workspace_.push_back( - Tensor::Empty({kIntAttnWorkspaceByte}, DataType::UInt(8), device)); + Tensor::Empty({kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device)); temp_int_pinned_attn_workspace_.push_back(Tensor::Empty( - {kIntAttnWorkspaceByte}, DataType::UInt(8), GetPreferredHostDevice(device))); + {kIntAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, GetPreferredHostDevice(device))); temp_float_attn_workspace_ = - Tensor::Empty({kFloatAttnWorkspaceByte}, DataType::UInt(8), device); + Tensor::Empty({kFloatAttnWorkspaceByte}, DLDataType{kDLUInt, 8, 1}, device); } if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMHA) != attn_kinds_.end()) { @@ -488,9 +488,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { temp_attn_output_device_ = Tensor::Empty({prefill_chunk_size_, num_qo_heads, v_head_dim}, dtype, device); temp_attn_lse_device_ = - Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DLDataType{kDLFloat, 32, 1}, device); 
merged_attn_lse_device_ = - Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DataType::Float(32), device); + Tensor::Empty({prefill_chunk_size_, num_qo_heads}, DLDataType{kDLFloat, 32, 1}, device); for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) { free_page_ids_.push_back(page_id); } diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 9926b3d235e8..a38acf6e1cdf 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -83,7 +83,7 @@ class RNNStateImpObj : public RNNStateObj { const ffi::Array init_layer_value_; /*! \brief We fix int32 to be the index dtype of auxiliary data. */ - const DLDataType dtype_aux_ = DLDataType(DataType::Int(32, 1)); + const DLDataType dtype_aux_ = DLDataType{kDLInt, 32, 1}; /******************* Storage Structures *******************/ diff --git a/src/runtime/vm/tensor_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc index ee77c5ddd8f0..62fd1a34c62f 100644 --- a/src/runtime/vm/tensor_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -64,7 +64,7 @@ TensorCacheMetadata::FileRecord::ParamRecord JSONAsParamRecord(const json::Objec TensorCacheMetadata::FileRecord::ParamRecord result; std::string dtype = json["dtype"].cast(); result.name = json["name"].cast(); - result.dtype = DataType(ffi::StringToDLDataType(dtype)); + result.dtype = ffi::StringToDLDataType(dtype); result.format = json["format"].cast(); result.nbytes = json["nbytes"].cast(); result.byte_offset = json["byteOffset"].cast(); @@ -154,7 +154,7 @@ void CopyTensorFromBytes(Tensor param, const void* data, size_t nbytes, Tensor TensorCacheMetadata::FileRecord::ParamRecord::Load( Device device, const std::string* raw_data, ffi::Optional* staging_buffer) const { Tensor arr = Tensor::Empty(shape, dtype, device); - if (dtype == DataType::Float(32) && format == "f32-to-bf16") { + if (dtype == DLDataType{kDLFloat, 32, 1} && format == "f32-to-bf16") { // decode bf16 to f32 std::vector buffer(nbytes / 
2); std::vector decoded(nbytes / 2); diff --git a/src/s_tir/analysis/calculate_allocated_memory.cc b/src/s_tir/analysis/calculate_allocated_memory.cc index 51330a63e88b..41df4ee4bb8a 100644 --- a/src/s_tir/analysis/calculate_allocated_memory.cc +++ b/src/s_tir/analysis/calculate_allocated_memory.cc @@ -76,7 +76,7 @@ class AllocBufferCalculator : public StmtExprVisitor { break; } } - size *= op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + size *= ((op->buffer->dtype.bits() + 7) / 8) * op->buffer->dtype.lanes(); _current_size[storage_scope] += size; _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]); StmtExprVisitor::VisitStmt_(op); diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index d77e715db1b6..bcde2d4b70bd 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -26,15 +26,13 @@ namespace tvm { namespace s_tir { using namespace tvm::tirx; -int32_t DataType2Int(const tvm::DataType& dtype) { +int32_t DataType2Int(DLDataType dtype) { static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType"); union { DLDataType src; int32_t dst; } converter; - converter.src.code = dtype.code(); - converter.src.bits = dtype.bits(); - converter.src.lanes = dtype.lanes(); + converter.src = dtype; return converter.dst; } @@ -57,7 +55,7 @@ ffi::String Int2DataTypeStr(int32_t dtype) { struct TResult { TResult() = default; - void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } + void Add(DLDataType dtype) { data_[DataType2Int(dtype)] += 1; } TResult operator+=(const TResult& rhs) { for (const auto& kv : rhs.data_) { @@ -98,7 +96,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const Node* op) final { \ TResult result = VisitExpr(op->a); \ result += VisitExpr(op->b); \ - result.Add(op->dtype); \ + result.Add(op->ty()->dtype); \ return result; \ } TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode); diff 
--git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 18eef8e2fe01..9fa0a7b0b325 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -348,7 +348,7 @@ ffi::Array BlockReadWriteDetector::CollectRegions( const tvm::arith::IntSet& range = regions[i][j]; if (range.CanProveSinglePoint(ana_)) { PrimExpr min = range.min(); - region.push_back(Range::FromMinExtent(min, MakeConst(min.dtype(), 1))); + region.push_back(Range::FromMinExtent(min, MakeConst(min.ty(), 1))); } else { region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } diff --git a/src/s_tir/analysis/verify_gpu_code.cc b/src/s_tir/analysis/verify_gpu_code.cc index bd7b7c92ba7c..837485d32de1 100644 --- a/src/s_tir/analysis/verify_gpu_code.cc +++ b/src/s_tir/analysis/verify_gpu_code.cc @@ -76,20 +76,22 @@ class GPUCodeVerifier : public StmtExprVisitor { break; } } + DLDataType dtype = op->buffer->dtype->dtype; if (storage_scope.rank == runtime::StorageRank::kLocal) { - local_memory_per_block_ += - static_cast(const_size) * op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + local_memory_per_block_ += static_cast(const_size) * (((dtype).bits + 7) / 8) * + static_cast((dtype).lanes); } else if (storage_scope.rank == runtime::StorageRank::kShared) { - shared_memory_per_block_ += - static_cast(const_size) * op->buffer->dtype.bytes() * op->buffer->dtype.lanes(); + shared_memory_per_block_ += static_cast(const_size) * (((dtype).bits + 7) / 8) * + static_cast((dtype).lanes); } - if (op->buffer->dtype.is_vector()) { - if (static_cast(op->buffer->dtype.lanes() * op->buffer->dtype.bytes()) > + if ((static_cast((dtype).lanes) > 1)) { + if (static_cast(static_cast((dtype).lanes) * (((dtype).bits + 7) / 8)) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->buffer->dtype.lanes() << ") times number of bytes (" - << 
op->buffer->dtype.bytes() << ") for dtype " << op->buffer->dtype - << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + s << "Number of lanes (" << static_cast((dtype).lanes) + << ") times number of bytes (" << (((dtype).bits + 7) / 8) << ") for dtype " + << op->buffer->dtype << " is greater than the maximum number of vector bytes (" + << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } } @@ -202,12 +204,16 @@ class GPUCodeVerifier : public StmtExprVisitor { void CheckBufferIndicesVectorizable(const ffi::Array indices) { for (const auto index : indices) { if (const auto* ramp = index.as()) { - if (!is_one(ramp->stride) && - static_cast(ramp->dtype.lanes() * ramp->dtype.bytes()) > max_vector_bytes_) { + PrimType ramp_ty = ramp->ty(); + DLDataType ramp_dtype = ramp_ty->dtype; + if (!is_one(ramp->stride) && ramp_ty.IsFixedLengthVector() && + static_cast(static_cast((ramp_dtype).lanes) * + (((ramp_dtype).bits + 7) / 8)) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << ramp->dtype.lanes() << ") times number of bytes (" - << ramp->dtype.bytes() << ") for dtype " << ramp->dtype - << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + s << "Number of lanes (" << static_cast((ramp_dtype).lanes) + << ") times number of bytes (" << (((ramp_dtype).bits + 7) / 8) << ") for dtype " + << ramp_dtype << " is greater than the maximum number of vector bytes (" + << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } } @@ -215,12 +221,16 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtype.is_vector()) { - if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + PrimType op_ty = op->ty(); + DLDataType op_dtype = op_ty->dtype; + if (op_ty.IsFixedLengthVector()) { + if (static_cast(static_cast((op_dtype).lanes) * + (((op_dtype).bits + 7) / 8)) > max_vector_bytes_) { std::stringstream s; - s << "Number of 
lanes (" << op->dtype.lanes() << ") times number of bytes (" - << op->dtype.bytes() << ") for dtype " << op->dtype - << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + s << "Number of lanes (" << static_cast((op_dtype).lanes) + << ") times number of bytes (" << (((op_dtype).bits + 7) / 8) << ") for dtype " + << op_dtype << " is greater than the maximum number of vector bytes (" + << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } } @@ -228,12 +238,16 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtype.is_vector()) { - if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { + PrimType op_ty = op->ty(); + DLDataType op_dtype = op_ty->dtype; + if (op_ty.IsFixedLengthVector()) { + if (static_cast(static_cast((op_dtype).lanes) * + (((op_dtype).bits + 7) / 8)) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" - << op->dtype.bytes() << ") for dtype " << op->dtype - << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + s << "Number of lanes (" << static_cast((op_dtype).lanes) + << ") times number of bytes (" << (((op_dtype).bits + 7) / 8) << ") for dtype " + << op_dtype << " is greater than the maximum number of vector bytes (" + << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } CheckBufferIndicesVectorizable(op->indices); @@ -242,13 +256,16 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtype.is_vector()) { - if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > - max_vector_bytes_) { + PrimType value_ty = op->value.ty(); + DLDataType value_dtype = value_ty->dtype; + if (value_ty.IsFixedLengthVector()) { + if (static_cast(static_cast((value_dtype).lanes) * + (((value_dtype).bits + 7) / 8)) > max_vector_bytes_) { std::stringstream s; - s << 
"Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes (" - << op->value->dtype.bytes() << ") for dtype " << op->value->dtype - << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; + s << "Number of lanes (" << static_cast((value_dtype).lanes) + << ") times number of bytes (" << (((value_dtype).bits + 7) / 8) << ") for dtype " + << value_dtype << " is greater than the maximum number of vector bytes (" + << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } CheckBufferIndicesVectorizable(op->indices); diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index e4e7c322ef55..5b6aeda19362 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -79,11 +79,11 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { ffi::Array args; args.push_back(StringImm(storage_scope)); args.push_back(IntImm::Int64(3)); - args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), + args.push_back(Call(PrimType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, texture.height, texture.depth})); args.push_back(IntImm::Int64(channel_size)); stmt = Bind(op->buffer->data, - Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)); + Call(op->buffer->data.ty(), builtin::nd_mem_alloc_with_scope(), args)); } return stmt; } diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 0dd939ad817a..d4297e42e4d2 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -100,7 +100,7 @@ class TextureFlattener : public TextureLoweringBase { if (IsTextureStorage(storage_scope)) { ffi::Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->value); - stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); + stmt = 
Evaluate(Call(args[0].ty(), builtin::texture2d_store(), args)); } return stmt; @@ -147,7 +147,7 @@ class TextureFlattener : public TextureLoweringBase { PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); PrimExpr depth_offset = SimplifyOffset(depth_dims, depth_indices); PrimExpr channel_size = IntImm( - DataType::Int(32, 1), *tirx::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); + PrimType::Int(32, 1), *tirx::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); args.push_back(row_offset); args.push_back(col_offset); args.push_back(depth_offset); diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index 787386c8ccb9..6fa2db0206e4 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -22,10 +22,10 @@ * \brief Data SLayout expression. */ #include +#include #include #include #include -#include #include #include #include @@ -113,8 +113,9 @@ SLayout::SLayout(const ffi::Array& axes) { data_ = std::move(node); } -SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) - TVM_FFI_CHECK(dtype.is_int(), TypeError) << "The input dtype should be integer type"; +SLayout::SLayout(const std::string& name, PrimType index_ty) { // NOLINT(*) + TVM_FFI_CHECK(index_ty.code() == DLDataTypeCode::kDLInt, TypeError) + << "The input dtype should be integer type"; if (name == "__undef__") return; auto node = ffi::make_object(); @@ -131,8 +132,8 @@ SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) if (c >= 'A' && c <= 'Z') { TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; - IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)), - Var(std::string(1, c), dtype), tirx::kDataPar); + IterVar axis(Range(IntImm(index_ty, 0), Var(std::string(1, c), index_ty)), + Var(std::string(1, c), index_ty), tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); } else { @@ -143,7 +144,7 @@ SLayout::SLayout(const 
std::string& name, DataType dtype) { // NOLINT(*) << factor << " for dimension " << c; std::stringstream name; name << factor << c; - IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(name.str(), dtype), + IterVar axis(Range(IntImm(index_ty, 0), IntImm(index_ty, factor)), Var(name.str(), index_ty), tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); @@ -174,8 +175,8 @@ SLayout::SLayout(const std::string& name, DataType dtype) { // NOLINT(*) extent = extent * factor->value; } std::string grouped_name = ss.str(); - IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(grouped_name, dtype), - tirx::kDataPar); + IterVar grouped_axis(Range(IntImm(index_ty, 0), IntImm(index_ty, extent)), + Var(grouped_name, index_ty), tirx::kDataPar); node->axes.push_back(grouped_axis); in_packing = false; @@ -231,21 +232,21 @@ ffi::Array SLayout::UnpackIterVar(IterVar packed_iter) { int64_t factor = 0, final_factor = 1; std::string name(packed_iter->var->name_hint.c_str()); - DataType dtype = packed_iter->var.dtype(); + PrimType index_ty = packed_iter->var.ty(); for (auto ch : name) { if (ch >= '0' && ch <= '9') { factor = factor * 10 + (ch - '0'); } else if (ch >= 'a' && ch <= 'z') { TVM_FFI_ICHECK(factor != 0) << "Invalid Factor Size"; - result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)), - Var(std::string(1, ch), dtype), tirx::kDataPar)); + result.push_back(IterVar(Range(IntImm(index_ty, 0), IntImm(index_ty, factor)), + Var(std::string(1, ch), index_ty), tirx::kDataPar)); final_factor *= factor; factor = 0; } else if (ch >= 'A' && ch <= 'Z') { TVM_FFI_ICHECK(factor == 0) << "Can't have non-zero factors for primal axis"; - result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch), dtype)), - Var(std::string(1, ch), dtype), tirx::kDataPar)); + result.push_back(IterVar(Range(IntImm(index_ty, 0), Var(std::string(1, ch), index_ty)), + Var(std::string(1, ch), index_ty), tirx::kDataPar)); } } @@ -256,7 +257,7 
@@ IterVar SLayout::PackIterVar(ffi::Array iter_vars) { std::stringstream name; size_t extent = 1; - DataType dtype = iter_vars[0]->dom->extent.as().value()->dtype; + PrimType index_ty = iter_vars[0]->dom->extent.as().value().ty(); for (auto itvar : iter_vars) { TVM_FFI_ICHECK(itvar->dom->extent.as()) << "Packed Axis can contain only Subordinate Axes"; @@ -264,7 +265,7 @@ IterVar SLayout::PackIterVar(ffi::Array iter_vars) { extent = extent * itvar->dom->extent.as().value()->value; } - return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(name.str(), dtype), + return IterVar(Range(IntImm(index_ty, 0), IntImm(index_ty, extent)), Var(name.str(), index_ty), tirx::kDataPar); } @@ -357,7 +358,8 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* if (axis == sub_axis) { const auto* sub_extent = inter_unpacked_axes[l]->dom->extent.as(); TVM_FFI_ICHECK(sub_extent) << "Expected Integer Extents for Offset Calculation"; - factor_ij = factor_ij * IntImm(sub_extent->dtype, sub_extent->value); + factor_ij = + factor_ij * IntImm(ffi::GetRef(sub_extent).ty(), sub_extent->value); } } } @@ -498,11 +500,11 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape << ", get " << orig_shape; } } - bind_map[orig_axis->var.get()] = IntImm(orig_axis->var->dtype, 0); + bind_map[orig_axis->var.get()] = IntImm(orig_axis->var.ty(), 0); } else { - bind_map[orig_axis->var.get()] = orig_axis->var->dtype == orig_shape->dtype + bind_map[orig_axis->var.get()] = orig_axis->var.ty()->dtype == orig_shape.ty()->dtype ? 
orig_shape - : cast(orig_axis->var->dtype, orig_shape); + : cast(orig_axis->var.ty(), orig_shape); } } // infer the target shape, @@ -583,7 +585,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("s_tir.SLayout", [](std::string name, DataType dtype) { return SLayout(name, dtype); }) + .def("s_tir.SLayout", [](std::string name, PrimType dtype) { return SLayout(name, dtype); }) .def("s_tir.SLayoutIndexOf", [](SLayout layout, std::string axis) -> int { return layout.IndexOf(axis); }) .def("s_tir.SLayoutFactorOf", diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index dc452b370037..73fa41773883 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -98,7 +98,7 @@ ffi::Array ArgInfo::FromPrimFunc(const tirx::PrimFunc& func) { for (const tirx::Var& arg : func->params) { if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { tirx::Buffer buffer = _buffer.value(); - result.push_back(TensorInfo(/*dtype=*/buffer->dtype, + result.push_back(TensorInfo(/*dtype=*/buffer->dtype->dtype, /*shape=*/AsVector(buffer->shape))); } else { TVM_FFI_THROW(ValueError) << "Unsupported argument type: " << arg; @@ -117,7 +117,7 @@ ffi::Array ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_prep /******** TensorInfo ********/ -TensorInfo::TensorInfo(runtime::DataType dtype, ffi::Shape shape) { +TensorInfo::TensorInfo(DLDataType dtype, ffi::Shape shape) { ffi::ObjectPtr n = ffi::make_object(); n->dtype = dtype; n->shape = shape; @@ -150,7 +150,7 @@ TensorInfo TensorInfo::FromJSON(const ffi::ObjectRef& json_obj) { } std::vector s; std::transform(shape.begin(), shape.end(), std::back_inserter(s), [](int64_t i) { return i; }); - return TensorInfo(DataType(dtype), ffi::Shape(s.begin(), s.end())); + return TensorInfo(dtype, ffi::Shape(s.begin(), s.end())); } /******** Repr ********/ @@ -182,10 +182,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { 
.def("s_tir.meta_schedule.ArgInfoFromPrimFunc", ArgInfo::FromPrimFunc) .def("s_tir.meta_schedule.ArgInfoFromEntryFunc", ArgInfo::FromEntryFunc) .def("s_tir.meta_schedule.ArgInfoFromJSON", ArgInfo::FromJSON) - .def("s_tir.meta_schedule.TensorInfo", - [](runtime::DataType dtype, ffi::Shape shape) -> TensorInfo { - return TensorInfo(dtype, shape); - }); + .def("s_tir.meta_schedule.TensorInfo", [](DLDataType dtype, ffi::Shape shape) -> TensorInfo { + return TensorInfo(dtype, shape); + }); } } // namespace meta_schedule diff --git a/src/s_tir/meta_schedule/database/database_utils.cc b/src/s_tir/meta_schedule/database/database_utils.cc index ea1473ae6500..826c38c8d1b0 100644 --- a/src/s_tir/meta_schedule/database/database_utils.cc +++ b/src/s_tir/meta_schedule/database/database_utils.cc @@ -32,7 +32,9 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { os << "null"; } else if (auto opt_int_imm = json_obj.try_cast()) { IntImm int_imm = *std::move(opt_int_imm); - if (int_imm->dtype == DataType::Bool()) { + PrimType int_ty = int_imm.ty(); + if (int_ty.MatchesElementType(DLDataTypeCode::kDLBool, 8) && !int_ty.IsScalableVector() && + !int_ty.IsFixedLengthVector()) { if (int_imm->value) { os << "true"; } else { @@ -154,7 +156,6 @@ class JSONTokenizer { bool NextFalse() { return NextLiteral("false", 5); } bool NextNumber(Token* token) { - using runtime::DataType; bool is_float = false; const char* st = cur_; for (; cur_ != end_; ++cur_) { diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index f0e3aa897cdd..2f87217db065 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -273,12 +273,12 @@ Pass SimplifyForFeatureExtraction() { HasBufferLoad(node->condition)) { return ffi::GetRef