diff --git a/CMakeLists.txt b/CMakeLists.txt index c7bbbf770..5e422c6fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,10 @@ endif () if (APPLE) set(METAL_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/third-party/metal-cpp + /usr/local/include/metal_irconverter ${CMAKE_CURRENT_SOURCE_DIR}/third-party/metal_irconverter_runtime) + find_library(METAL_IRCONVERTER_LIBRARY metalirconverter PATHS /usr/local/lib REQUIRED) + set(METAL_LIBRARIES ${METAL_IRCONVERTER_LIBRARY}) set(OFFLOADTEST_ENABLE_METAL On) endif () diff --git a/include/Support/Pipeline.h b/include/Support/Pipeline.h index 01da45b76..e1aeb8e74 100644 --- a/include/Support/Pipeline.h +++ b/include/Support/Pipeline.h @@ -62,6 +62,33 @@ enum class ResourceKind { SampledTexture2D, }; +enum class DescriptorKind { UAV, SRV, CBV, SAMPLER }; + +static DescriptorKind getDescriptorKind(ResourceKind RK) { + switch (RK) { + case ResourceKind::Buffer: + case ResourceKind::StructuredBuffer: + case ResourceKind::ByteAddressBuffer: + case ResourceKind::Texture2D: + return DescriptorKind::SRV; + + case ResourceKind::RWStructuredBuffer: + case ResourceKind::RWBuffer: + case ResourceKind::RWByteAddressBuffer: + case ResourceKind::RWTexture2D: + return DescriptorKind::UAV; + + case ResourceKind::ConstantBuffer: + return DescriptorKind::CBV; + + case ResourceKind::Sampler: + return DescriptorKind::SAMPLER; + case ResourceKind::SampledTexture2D: + llvm_unreachable("Sampled textures aren't supported!"); + } + llvm_unreachable("All cases handled"); +} + enum class FilterMode { Nearest, Linear }; enum class AddressMode { Clamp, Repeat, Mirror, Border, MirrorOnce }; @@ -403,7 +430,6 @@ struct Shader { Stages Stage; std::string Entry; std::unique_ptr Shader; - std::unique_ptr Reflection; int DispatchSize[3]; llvm::SmallVector SpecializationConstants; }; diff --git a/lib/API/CMakeLists.txt b/lib/API/CMakeLists.txt index ec3d24feb..83fe8c91f 100644 --- a/lib/API/CMakeLists.txt +++ b/lib/API/CMakeLists.txt @@ -19,12 +19,15 @@ 
if (OFFLOADTEST_ENABLE_D3D12) endif() if (APPLE) - list(APPEND api_sources MTL/MTLDevice.cpp) + list(APPEND api_sources MTL/MTLDevice.cpp + MTL/MTLDescriptorHeap.cpp + MTL/MTLTopLevelArgumentBuffer.cpp) list(APPEND api_libraries "-framework Metal" "-framework MetalKit" "-framework AppKit" "-framework Foundation" - "-framework QuartzCore") + "-framework QuartzCore" + ${METAL_LIBRARIES}) list(APPEND api_headers PRIVATE ${METAL_INCLUDE_DIRS}) endif() diff --git a/lib/API/DX/Device.cpp b/lib/API/DX/Device.cpp index c74913d27..c97db5f33 100644 --- a/lib/API/DX/Device.cpp +++ b/lib/API/DX/Device.cpp @@ -143,33 +143,6 @@ static D3D12_RESOURCE_DIMENSION getDXDimension(ResourceKind RK) { llvm_unreachable("All cases handled"); } -enum DXResourceKind { UAV, SRV, CBV, SAMPLER }; - -static DXResourceKind getDXKind(offloadtest::ResourceKind RK) { - switch (RK) { - case ResourceKind::Buffer: - case ResourceKind::StructuredBuffer: - case ResourceKind::ByteAddressBuffer: - case ResourceKind::Texture2D: - return SRV; - - case ResourceKind::RWStructuredBuffer: - case ResourceKind::RWBuffer: - case ResourceKind::RWByteAddressBuffer: - case ResourceKind::RWTexture2D: - return UAV; - - case ResourceKind::ConstantBuffer: - return CBV; - - case ResourceKind::Sampler: - return SAMPLER; - case ResourceKind::SampledTexture2D: - llvm_unreachable("Sampled textures aren't supported in DirectX!"); - } - llvm_unreachable("All cases handled"); -} - static llvm::Expected getResourceDescription(const Resource &R) { const D3D12_RESOURCE_DIMENSION Dimension = getDXDimension(R.Kind); @@ -189,7 +162,8 @@ getResourceDescription(const Resource &R) { if (R.isTexture()) Layout = - R.IsReserved && (getDXKind(R.Kind) == SRV || getDXKind(R.Kind) == UAV) + R.IsReserved && (getDescriptorKind(R.Kind) == DescriptorKind::SRV || + getDescriptorKind(R.Kind) == DescriptorKind::UAV) ? 
D3D12_TEXTURE_LAYOUT_64KB_UNDEFINED_SWIZZLE : D3D12_TEXTURE_LAYOUT_UNKNOWN; else @@ -726,17 +700,17 @@ class DXDevice : public offloadtest::Device { uint32_t DescriptorIdx = 0; const uint32_t StartRangeIdx = RangeIdx; for (const auto &R : D.Resources) { - switch (getDXKind(R.Kind)) { - case SRV: + switch (getDescriptorKind(R.Kind)) { + case DescriptorKind::SRV: Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV; break; - case UAV: + case DescriptorKind::UAV: Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV; break; - case CBV: + case DescriptorKind::CBV: Ranges.get()[RangeIdx].RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_CBV; break; - case SAMPLER: + case DescriptorKind::SAMPLER: llvm_unreachable("Not implemented yet."); } Ranges.get()[RangeIdx].NumDescriptors = R.getArraySize(); @@ -1224,29 +1198,29 @@ class DXDevice : public offloadtest::Device { [&IS, this](Resource &R, llvm::SmallVectorImpl &Resources) -> llvm::Error { - switch (getDXKind(R.Kind)) { - case SRV: { + switch (getDescriptorKind(R.Kind)) { + case DescriptorKind::SRV: { auto ExRes = createSRV(R, IS); if (!ExRes) return ExRes.takeError(); Resources.push_back(std::make_pair(&R, *ExRes)); break; } - case UAV: { + case DescriptorKind::UAV: { auto ExRes = createUAV(R, IS); if (!ExRes) return ExRes.takeError(); Resources.push_back(std::make_pair(&R, *ExRes)); break; } - case CBV: { + case DescriptorKind::CBV: { auto ExRes = createCBV(R, IS); if (!ExRes) return ExRes.takeError(); Resources.push_back(std::make_pair(&R, *ExRes)); break; } - case SAMPLER: + case DescriptorKind::SAMPLER: return llvm::createStringError( std::errc::not_supported, "Samplers are not yet implemented for DirectX."); @@ -1266,17 +1240,17 @@ class DXDevice : public offloadtest::Device { uint32_t HeapIndex = 0; for (auto &T : IS.DescTables) { for (auto &R : T.Resources) { - switch (getDXKind(R.first->Kind)) { - case SRV: + switch (getDescriptorKind(R.first->Kind)) { + case DescriptorKind::SRV: HeapIndex = 
bindSRV(*(R.first), IS, HeapIndex, R.second); break; - case UAV: + case DescriptorKind::UAV: HeapIndex = bindUAV(*(R.first), IS, HeapIndex, R.second); break; - case CBV: + case DescriptorKind::CBV: HeapIndex = bindCBV(*(R.first), IS, HeapIndex, R.second); break; - case SAMPLER: + case DescriptorKind::SAMPLER: llvm_unreachable("Not implemented yet."); } } @@ -1410,23 +1384,23 @@ class DXDevice : public offloadtest::Device { return llvm::createStringError( std::errc::value_too_large, "Root descriptor cannot refer to resource arrays."); - switch (getDXKind(RootDescIt->first->Kind)) { - case SRV: + switch (getDescriptorKind(RootDescIt->first->Kind)) { + case DescriptorKind::SRV: IS.CB->CmdList->SetComputeRootShaderResourceView( RootParamIndex++, RootDescIt->second.back().Buffer->GetGPUVirtualAddress()); break; - case UAV: + case DescriptorKind::UAV: IS.CB->CmdList->SetComputeRootUnorderedAccessView( RootParamIndex++, RootDescIt->second.back().Buffer->GetGPUVirtualAddress()); break; - case CBV: + case DescriptorKind::CBV: IS.CB->CmdList->SetComputeRootConstantBufferView( RootParamIndex++, RootDescIt->second.back().Buffer->GetGPUVirtualAddress()); break; - case SAMPLER: + case DescriptorKind::SAMPLER: llvm_unreachable("Not implemented yet."); } ++RootDescIt; diff --git a/lib/API/MTL/MTLDescriptorHeap.cpp b/lib/API/MTL/MTLDescriptorHeap.cpp new file mode 100644 index 000000000..ee759b1cc --- /dev/null +++ b/lib/API/MTL/MTLDescriptorHeap.cpp @@ -0,0 +1,72 @@ +#include "MTLDescriptorHeap.h" +#include "MetalIRConverter.h" + +using namespace offloadtest; + +static NS::UInteger getDescriptorHeapBindPoint(MTLDescriptorHeapType Type) { + switch (Type) { + case MTLDescriptorHeapType::CBV_SRV_UAV: + return kIRDescriptorHeapBindPoint; + case MTLDescriptorHeapType::Sampler: + return kIRSamplerHeapBindPoint; + } + llvm_unreachable("All cases handled."); +} + +MTLGPUDescriptorHandle & +MTLGPUDescriptorHandle::Offset(int32_t OffsetInDescriptors) { + Ptr = MTL::GPUAddress(int64_t(Ptr) + 
int64_t(OffsetInDescriptors) * + sizeof(IRDescriptorTableEntry)); + return *this; +} + +llvm::Expected> +MTLDescriptorHeap::create(MTL::Device *Device, + const MTLDescriptorHeapDesc &Desc) { + if (!Device) + return llvm::createStringError(std::errc::invalid_argument, + "Invalid MTL::Device pointer."); + + if (Desc.NumDescriptors == 0) + return llvm::createStringError(std::errc::invalid_argument, + "Invalid descriptor heap description."); + + MTL::Buffer *Buf = + Device->newBuffer(Desc.NumDescriptors * sizeof(IRDescriptorTableEntry), + MTL::ResourceStorageModeShared); + if (!Buf) + return llvm::createStringError(std::errc::not_enough_memory, + "Failed to create MTLDescriptorHeap."); + return std::make_unique(Desc, Buf); +} + +MTLDescriptorHeap::~MTLDescriptorHeap() { + if (Buffer) + Buffer->release(); +} + +MTLGPUDescriptorHandle +MTLDescriptorHeap::getGPUDescriptorHandleForHeapStart() const { + return MTLGPUDescriptorHandle{Buffer->gpuAddress()}; +} + +IRDescriptorTableEntry * +MTLDescriptorHeap::getEntryHandle(uint32_t Index) const { + assert(Index < Desc.NumDescriptors && "Descriptor index out of bounds."); + return static_cast(Buffer->contents()) + Index; +} + +void MTLDescriptorHeap::bind(MTL::RenderCommandEncoder *Encoder) { + Encoder->useResource(Buffer, MTL::ResourceUsageRead); + // Dynamic resource indexing + const NS::UInteger BindPoint = getDescriptorHeapBindPoint(Desc.Type); + Encoder->setVertexBuffer(Buffer, 0, BindPoint); + Encoder->setFragmentBuffer(Buffer, 0, BindPoint); +} + +void MTLDescriptorHeap::bind(MTL::ComputeCommandEncoder *Encoder) { + Encoder->useResource(Buffer, MTL::ResourceUsageRead); + // Dynamic resource indexing + const NS::UInteger BindPoint = getDescriptorHeapBindPoint(Desc.Type); + Encoder->setBuffer(Buffer, 0, BindPoint); +} diff --git a/lib/API/MTL/MTLDescriptorHeap.h b/lib/API/MTL/MTLDescriptorHeap.h new file mode 100644 index 000000000..76b980136 --- /dev/null +++ b/lib/API/MTL/MTLDescriptorHeap.h @@ -0,0 +1,69 @@ +//===- 
MTLDescriptorHeap.h - Metal Descriptor Heap ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#ifndef OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H +#define OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H + +#include "llvm/Support/Error.h" +#include + +// Forward declarations +namespace MTL { +class Device; +class Buffer; +class RenderCommandEncoder; +class ComputeCommandEncoder; +} // namespace MTL +struct IRDescriptorTableEntry; + +namespace offloadtest { +struct MTLGPUDescriptorHandle { + MTLGPUDescriptorHandle &Offset(int32_t OffsetInDescriptors); + + uint64_t Ptr; +}; + +enum class MTLDescriptorHeapType { + CBV_SRV_UAV, + Sampler, +}; + +struct MTLDescriptorHeapDesc { + MTLDescriptorHeapType Type; + uint32_t NumDescriptors; +}; + +// MTLDescriptorHeap mimics the D3D12 descriptor heap concept, except +// MTLDescriptorHeap is always shader visible and meant to be used +// by the argument buffer for shader resource binding with the explicit root +// signature layout. 
+class MTLDescriptorHeap { + MTLDescriptorHeapDesc Desc; + MTL::Buffer *Buffer; + +public: + static llvm::Expected> + create(MTL::Device *Device, const MTLDescriptorHeapDesc &Desc); + + MTLDescriptorHeap(const MTLDescriptorHeapDesc &Desc, MTL::Buffer *Buffer) + : Desc(Desc), Buffer(Buffer) {} + ~MTLDescriptorHeap(); + + MTLGPUDescriptorHandle getGPUDescriptorHandleForHeapStart() const; + + IRDescriptorTableEntry *getEntryHandle(uint32_t Index) const; + + void bind(MTL::RenderCommandEncoder *Encoder); + void bind(MTL::ComputeCommandEncoder *Encoder); +}; +} // namespace offloadtest + +#endif // OFFLOADTEST_API_MTL_MTLDESCRIPTORHEAP_H diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 4adb5ca0f..fad5dc536 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -7,10 +7,13 @@ #define IR_RUNTIME_METALCPP #define IR_PRIVATE_IMPLEMENTATION +#include "metal_irconverter.h" #include "metal_irconverter_runtime.h" #include "API/Device.h" +#include "MTLDescriptorHeap.h" #include "MTLResources.h" +#include "MTLTopLevelArgumentBuffer.h" #include "Support/Pipeline.h" #include "llvm/ADT/ScopeExit.h" @@ -19,6 +22,7 @@ #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" #include +#include using namespace offloadtest; @@ -34,6 +38,53 @@ static llvm::Error toError(NS::Error *Err) { return llvm::createStringError(EC, ErrMsg); } +static llvm::Error toError(const IRError *Err, llvm::StringRef Context) { + if (!Err) + return llvm::Error::success(); + + const uint32_t Code = IRErrorGetCode(Err); + if (IRErrorCodeNoError == Code) + return llvm::Error::success(); + + const std::error_code EC = + std::error_code(static_cast(Code), std::generic_category()); + llvm::SmallString<64> ErrMsg; + llvm::raw_svector_ostream OS(ErrMsg); + OS << Context << ": "; + + switch (Code) { +#define IR_ERR(x) \ + case x: \ + OS << #x; \ + break; + + IR_ERR(IRErrorCodeShaderRequiresRootSignature); + IR_ERR(IRErrorCodeUnrecognizedRootSignatureDescriptor); + 
IR_ERR(IRErrorCodeUnrecognizedParameterTypeInRootSignature); + IR_ERR(IRErrorCodeResourceNotReferencedByRootSignature); + IR_ERR(IRErrorCodeShaderIncompatibleWithDualSourceBlending); + IR_ERR(IRErrorCodeUnsupportedWaveSize); + IR_ERR(IRErrorCodeUnsupportedInstruction); + IR_ERR(IRErrorCodeCompilationError); + IR_ERR(IRErrorCodeFailedToSynthesizeStageInFunction); + IR_ERR(IRErrorCodeFailedToSynthesizeStreamOutFunction); + IR_ERR(IRErrorCodeFailedToSynthesizeIndirectIntersectionFunction); + IR_ERR(IRErrorCodeUnableToVerifyModule); + IR_ERR(IRErrorCodeUnableToLinkModule); + IR_ERR(IRErrorCodeUnrecognizedDXILHeader); + IR_ERR(IRErrorCodeInvalidRaytracingAttribute); + IR_ERR(IRErrorCodeNullHullShaderInputOutputMismatch); + IR_ERR(IRErrorCodeInvalidRaytracingUserAttributeSize); + IR_ERR(IRErrorCodeIncorrectHitgroupType); + IR_ERR(IRErrorCodeFP64Usage); + IR_ERR(IRErrorCodeUnknown); + default: + break; +#undef IR_ERR + } + return llvm::createStringError(EC, ErrMsg); +} + #define MTLFormats(FMT) \ if (Channels == 1) \ return MTL::PixelFormatR##FMT; \ @@ -74,7 +125,47 @@ static MTL::VertexFormat getMTLVertexFormat(DataFormat Format, int Channels) { return MTL::VertexFormatInvalid; } +static IRShaderStage getShaderStage(Stages Stage) { + switch (Stage) { + case Stages::Compute: + return IRShaderStageCompute; + case Stages::Vertex: + return IRShaderStageVertex; + case Stages::Pixel: + return IRShaderStageFragment; + } + llvm_unreachable("All cases handled"); +} + namespace { +struct MTLDeleter { + template void operator()(T *Arg) const { + if (Arg) + Arg->release(); + } +}; + +template using MTLPtr = std::unique_ptr; + +template struct IRDeleter { + template constexpr void operator()(T *Arg) const { Fn(Arg); } +}; + +using IRCompilerPtr = std::unique_ptr>; +using IRObjectPtr = std::unique_ptr>; +using IRRootSignaturePtr = + std::unique_ptr>; +using IRMetalLibBinaryPtr = + std::unique_ptr>; +using IRShaderReflectionPtr = + std::unique_ptr>; +using IRErrorPtr = 
std::unique_ptr>; + +struct MetalIR { + IRMetalLibBinaryPtr Binary; + IRShaderReflectionPtr Reflection; +}; + class MTLQueue : public offloadtest::Queue { public: MTL::CommandQueue *Queue; @@ -177,13 +268,23 @@ class MTLDevice : public offloadtest::Device { MTL::Device *Device; MTLQueue GraphicsQueue; + struct ResourceSet { + MTLPtr Resource; + ResourceSet(MTL::Resource *Resource) : Resource(Resource) {} + }; + + // ResourceBundle will contain one ResourceSet for a singular resource + // or multiple ResourceSets for resource array. + using ResourceBundle = llvm::SmallVector; + using ResourcePair = std::pair; + + struct DescriptorTable { + llvm::SmallVector Resources; + }; + struct InvocationState { InvocationState() { Pool = NS::AutoreleasePool::alloc()->init(); } ~InvocationState() { - for (MTL::Texture *T : Textures) - T->release(); - for (MTL::Buffer *B : Buffers) - B->release(); if (ComputePipeline) ComputePipeline->release(); if (RenderPipeline) @@ -193,20 +294,181 @@ class MTLDevice : public offloadtest::Device { } NS::AutoreleasePool *Pool = nullptr; + IRRootSignaturePtr RootSig; + std::unique_ptr ArgBuffer; + std::unique_ptr DescHeap; MTL::ComputePipelineState *ComputePipeline = nullptr; MTL::RenderPipelineState *RenderPipeline = nullptr; - MTL::Buffer *ArgBuffer; MTL::Buffer *VertexBuffer; MTL::VertexDescriptor *VertexDescriptor; - llvm::SmallVector Textures; - llvm::SmallVector Buffers; std::shared_ptr FrameBufferTexture; std::shared_ptr FrameBufferReadback; std::shared_ptr DepthStencil; std::unique_ptr CB; std::unique_ptr Fence; + IRShaderReflectionPtr ComputeReflection; + + llvm::SmallVector DescTables; + // TODO: Support RootResources? 
}; + llvm::Error createRootSignature(const Pipeline &P, InvocationState &State) { + std::vector RootParams; + const uint32_t DescriptorCount = P.getDescriptorCount(); + const std::unique_ptr Ranges = + std::unique_ptr( + new IRDescriptorRange1[DescriptorCount]); + + uint32_t RangeIdx = 0; + for (const auto &D : P.Sets) { + uint32_t DescriptorIdx = 0; + const uint32_t StartRangeIdx = RangeIdx; + for (const auto &R : D.Resources) { + auto &Range = Ranges.get()[RangeIdx]; + switch (getDescriptorKind(R.Kind)) { + case DescriptorKind::SRV: + Range.RangeType = IRDescriptorRangeTypeSRV; + break; + case DescriptorKind::UAV: + Range.RangeType = IRDescriptorRangeTypeUAV; + break; + case DescriptorKind::CBV: + Range.RangeType = IRDescriptorRangeTypeCBV; + break; + case DescriptorKind::SAMPLER: + llvm_unreachable("Not implemented yet."); + } + Range.NumDescriptors = R.getArraySize(); + Range.BaseShaderRegister = R.DXBinding.Register; + Range.RegisterSpace = R.DXBinding.Space; + Range.OffsetInDescriptorsFromTableStart = DescriptorIdx; + llvm::outs() << "DescriptorRange[" << RangeIdx << "] {" + << " Type=" << static_cast(Range.RangeType) + << "," + << " NumDescriptors=" << Range.NumDescriptors << "," + << " BaseShaderRegister=" << Range.BaseShaderRegister + << "," + << " RegisterSpace=" << Range.RegisterSpace << "," + << " OffsetInDescriptorsFromTableStart=" + << Range.OffsetInDescriptorsFromTableStart << " }\n"; + RangeIdx++; + DescriptorIdx += R.getArraySize(); + } + + auto &Param = RootParams.emplace_back(); + Param.ParameterType = IRRootParameterTypeDescriptorTable; + Param.DescriptorTable.NumDescriptorRanges = + static_cast(D.Resources.size()); + Param.DescriptorTable.pDescriptorRanges = &Ranges.get()[StartRangeIdx]; + Param.ShaderVisibility = IRShaderVisibilityAll; + } + + // NOTE: Attempting to create a RS with version 1.0 seems to fail + // with IRErrorCodeUnrecognizedRootSignatureDescriptor, creating with 1.1 + // instead + IRVersionedRootSignatureDescriptor 
VersionedDesc = {}; + VersionedDesc.version = IRRootSignatureVersion_1_1; + auto &Desc = VersionedDesc.desc_1_1; + Desc.NumParameters = static_cast(RootParams.size()); + Desc.pParameters = RootParams.data(); + Desc.NumStaticSamplers = 0; + Desc.pStaticSamplers = nullptr; + Desc.Flags = P.isGraphics() + ? IRRootSignatureFlagAllowInputAssemblerInputLayout + : IRRootSignatureFlagNone; + + IRError *Err = nullptr; + IRRootSignaturePtr RootSig( + IRRootSignatureCreateFromDescriptor(&VersionedDesc, &Err)); + if (!RootSig) + return toError(IRErrorPtr(Err).get(), "Failed to create root signature"); + + State.RootSig = std::move(RootSig); + + auto ArgBufferOrErr = + MTLTopLevelArgumentBuffer::create(Device, State.RootSig.get()); + if (!ArgBufferOrErr) + return ArgBufferOrErr.takeError(); + + State.ArgBuffer = std::move(*ArgBufferOrErr); + return llvm::Error::success(); + } + + llvm::Error createDescriptorHeap(Pipeline &P, InvocationState &State) { + if (P.getDescriptorCount() == 0) { + llvm::outs() + << "No descriptors found, skipping descriptor heap creation.\n"; + return llvm::Error::success(); + } + const uint32_t DescriptorCount = P.getDescriptorCountWithFlattenedArrays(); + const MTLDescriptorHeapDesc HeapDesc = {MTLDescriptorHeapType::CBV_SRV_UAV, + DescriptorCount}; + + auto DescHeapOrErr = MTLDescriptorHeap::create(Device, HeapDesc); + if (!DescHeapOrErr) + return DescHeapOrErr.takeError(); + + State.DescHeap = std::move(*DescHeapOrErr); + llvm::outs() << "Descriptor heap created with " << DescriptorCount + << " descriptors.\n"; + return llvm::Error::success(); + } + + llvm::Expected convertToMetalIR(const Pipeline &P, const Shader &S, + const InvocationState &State) { + IRCompilerPtr Compiler(IRCompilerCreate()); + if (!Compiler) + return llvm::createStringError(std::errc::not_supported, + "Failed to create IR compiler instance."); + + if (!State.RootSig) + return llvm::createStringError( + std::errc::invalid_argument, + "Root signature must be created before 
converting to Metal IR."); + + // Configure IR compiler settings + IRCompilerSetEntryPointName(Compiler.get(), S.Entry.c_str()); + IRCompilerSetGlobalRootSignature(Compiler.get(), State.RootSig.get()); + if (P.isGraphics()) { + // Matches DX::Device backend: + // PSODesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; + IRCompilerSetInputTopology(Compiler.get(), IRInputTopologyTriangle); + } + + const llvm::StringRef Program = S.Shader->getBuffer(); + IRObject *DXIL = IRObjectCreateFromDXIL( + reinterpret_cast(Program.data()), Program.size(), + IRBytecodeOwnershipNone); + + // Compile DXIL to Metal IR + IRError *Err = nullptr; + IRObjectPtr ResultIR( + IRCompilerAllocCompileAndLink(Compiler.get(), nullptr, DXIL, &Err)); + if (Err) + return toError(IRErrorPtr(Err).get(), + "Failed to compile and link DXIL to Metal IR"); + + // Retrieve Metallib and shader reflection from the compiled IR object + const IRShaderStage ShaderStage = getShaderStage(S.Stage); + auto MetalLib = IRMetalLibBinaryPtr(IRMetalLibBinaryCreate()); + if (!IRObjectGetMetalLibBinary(ResultIR.get(), ShaderStage, + MetalLib.get())) { + return llvm::createStringError( + std::errc::not_supported, + "Failed to retrieve Metal library binary from " + "IR object."); + } + + auto Reflection = IRShaderReflectionPtr(IRShaderReflectionCreate()); + if (!IRObjectGetReflection(ResultIR.get(), ShaderStage, Reflection.get())) { + return llvm::createStringError( + std::errc::not_supported, + "Failed to retrieve shader reflection from IR object."); + } + + return MetalIR{std::move(MetalLib), std::move(Reflection)}; + } + llvm::Error setupVertexShader(InvocationState &IS, const Pipeline &P, MTL::Function *Fn) { if (P.Bindings.VertexBufferPtr) { @@ -270,11 +532,14 @@ class MTLDevice : public offloadtest::Device { return llvm::createStringError( std::errc::invalid_argument, "Compute pipeline must have exactly one compute shader."); - const llvm::StringRef Program = P.Shaders[0].Shader->getBuffer(); - 
dispatch_data_t Data = dispatch_data_create( - Program.data(), Program.size(), dispatch_get_main_queue(), - ^{ - }); + + auto MetalIR = convertToMetalIR(P, P.Shaders[0], IS); + if (!MetalIR) + return MetalIR.takeError(); + + IS.ComputeReflection = std::move(MetalIR->Reflection); + + dispatch_data_t Data = IRMetalLibGetBytecodeData(MetalIR->Binary.get()); MTL::Library *Lib = Device->newLibrary(Data, &Error); if (Error) return toError(Error); @@ -291,11 +556,11 @@ class MTLDevice : public offloadtest::Device { MTL::RenderPipelineDescriptor::alloc()->init(); IS.Pool->addObject(Desc); for (const auto &S : P.Shaders) { - const llvm::StringRef Program = S.Shader->getBuffer(); - dispatch_data_t Data = dispatch_data_create( - Program.data(), Program.size(), dispatch_get_main_queue(), - ^{ - }); + auto MetalIR = convertToMetalIR(P, S, IS); + if (!MetalIR) + return MetalIR.takeError(); + + dispatch_data_t Data = IRMetalLibGetBytecodeData(MetalIR->Binary.get()); MTL::Library *Lib = Device->newLibrary(Data, &Error); if (Error) return toError(Error); @@ -357,34 +622,27 @@ class MTLDevice : public offloadtest::Device { return llvm::Error::success(); } - llvm::Error createDescriptor(Resource &R, InvocationState &IS, - const uint32_t HeapIdx) { - auto *TablePtr = (IRDescriptorTableEntry *)IS.ArgBuffer->contents(); - - assert(R.BufferPtr->ArraySize == 1 && - "Resource arrays are not yet supported on Metal."); + // Creates a Metal resource (buffer or texture) for the given Resource at the + // specified array index. 
+ llvm::Expected + createResource(Resource &R, size_t ResourceArrayIndex = 0) { + const offloadtest::CPUBuffer &B = *R.BufferPtr; if (R.isRaw()) { MTL::Buffer *Buf = - Device->newBuffer(R.BufferPtr->Data.back().get(), R.size(), + Device->newBuffer(B.Data[ResourceArrayIndex].get(), R.size(), MTL::ResourceStorageModeManaged); - IRBufferView View = {}; - View.buffer = Buf; - View.bufferSize = R.size(); - - IRDescriptorTableSetBufferView(&TablePtr[HeapIdx], &View); - IS.Buffers.push_back(Buf); + Buf->didModifyRange(NS::Range::Make(0, Buf->length())); + return Buf; } else { - const uint64_t Width = R.isTexture() ? R.BufferPtr->OutputProps.Width - : R.size() / R.getElementSize(); - const uint64_t Height = - R.isTexture() ? R.BufferPtr->OutputProps.Height : 1; + const uint64_t Width = + R.isTexture() ? B.OutputProps.Width : R.size() / R.getElementSize(); + const uint64_t Height = R.isTexture() ? B.OutputProps.Height : 1; MTL::TextureUsage UsageFlags = MTL::ResourceUsageRead; if (R.isReadWrite()) UsageFlags |= MTL::ResourceUsageWrite; MTL::TextureDescriptor *Desc = nullptr; - const MTL::PixelFormat Format = - getMTLFormat(R.BufferPtr->Format, R.BufferPtr->Channels); + const MTL::PixelFormat Format = getMTLFormat(B.Format, B.Channels); switch (R.Kind) { case ResourceKind::Buffer: case ResourceKind::RWBuffer: @@ -410,33 +668,190 @@ class MTLDevice : public offloadtest::Device { MTL::Texture *NewTex = Device->newTexture(Desc); NewTex->replaceRegion(MTL::Region(0, 0, Width, Height), 0, - R.BufferPtr->Data.back().get(), + B.Data[ResourceArrayIndex].get(), Width * R.getElementSize()); + return NewTex; + } + } - IS.Textures.push_back(NewTex); + llvm::Expected createSRV(Resource &R, InvocationState &IS) { + ResourceBundle Bundle; - IRDescriptorTableSetTexture(&TablePtr[HeapIdx], NewTex, 0, 0); + for (size_t RegOffset = 0; RegOffset < R.BufferPtr->Data.size(); + ++RegOffset) { + llvm::outs() << "Creating SRV: { Size = " << R.size() << ", Register = t" + << R.DXBinding.Register + 
RegOffset + << ", Space = " << R.DXBinding.Space; + llvm::outs() << " }\n"; + + auto ResourceOrErr = createResource(R, RegOffset); + if (!ResourceOrErr) + return ResourceOrErr.takeError(); + + Bundle.emplace_back(ResourceOrErr.get()); } + return Bundle; + } - return llvm::Error::success(); + // TODO: counter buffer via IRRuntimeCreateAppendBufferView? + llvm::Expected createUAV(Resource &R, InvocationState &IS) { + ResourceBundle Bundle; + + for (size_t RegOffset = 0; RegOffset < R.BufferPtr->Data.size(); + ++RegOffset) { + llvm::outs() << "Creating UAV: { Size = " << R.size() << ", Register = u" + << R.DXBinding.Register + RegOffset + << ", Space = " << R.DXBinding.Space + << ", HasCounter = " << R.HasCounter; + llvm::outs() << " }\n"; + + auto ResourceOrErr = createResource(R, RegOffset); + if (!ResourceOrErr) + return ResourceOrErr.takeError(); + + Bundle.emplace_back(ResourceOrErr.get()); + } + return Bundle; + } + + llvm::Expected createCBV(Resource &R, InvocationState &IS) { + ResourceBundle Bundle; + + for (size_t RegOffset = 0; RegOffset < R.BufferPtr->Data.size(); + ++RegOffset) { + llvm::outs() << "Creating CBV: { Size = " << R.size() << ", Register = b" + << R.DXBinding.Register + RegOffset + << ", Space = " << R.DXBinding.Space << " }\n"; + + auto ResourceOrErr = createResource(R, RegOffset); + if (!ResourceOrErr) + return ResourceOrErr.takeError(); + + Bundle.emplace_back(ResourceOrErr.get()); + } + return Bundle; + } + + void createDescriptor(Resource &R, MTL::Resource *Resource, + IRDescriptorTableEntry *Entry) { + if (R.isRaw()) { + IRBufferView View = {}; + View.buffer = static_cast(Resource); + View.bufferSize = R.size(); + IRDescriptorTableSetBufferView(Entry, &View); + } else { + MTL::Texture *Tex = static_cast(Resource); + IRDescriptorTableSetTexture(Entry, Tex, 0, 0); + } + } + + // returns the next available HeapIdx + uint32_t bindSRV(Resource &R, InvocationState &IS, uint32_t HeapIdx, + const ResourceBundle &ResBundle) { + const uint32_t 
EltSize = R.getElementSize(); + const uint32_t NumElts = R.size() / EltSize; + + for (const ResourceSet &RS : ResBundle) { + llvm::outs() << "SRV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize + << " NumElts = " << NumElts << "\n"; + createDescriptor(R, RS.Resource.get(), + IS.DescHeap->getEntryHandle(HeapIdx)); + HeapIdx++; + } + return HeapIdx; + } + + // returns the next available HeapIdx + uint32_t bindUAV(Resource &R, InvocationState &IS, uint32_t HeapIdx, + const ResourceBundle &ResBundle) { + const uint32_t EltSize = R.getElementSize(); + const uint32_t NumElts = R.size() / EltSize; + for (const ResourceSet &RS : ResBundle) { + llvm::outs() << "UAV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize + << " NumElts = " << NumElts << "\n"; + createDescriptor(R, RS.Resource.get(), + IS.DescHeap->getEntryHandle(HeapIdx)); + HeapIdx++; + } + return HeapIdx; + } + + // returns the next available HeapIdx + uint32_t bindCBV(Resource &R, InvocationState &IS, uint32_t HeapIdx, + const ResourceBundle &ResBundle) { + for (const ResourceSet &RS : ResBundle) { + llvm::outs() << "CBV: HeapIdx = " << HeapIdx << " Size = " << R.size() + << "\n"; + createDescriptor(R, RS.Resource.get(), + IS.DescHeap->getEntryHandle(HeapIdx)); + HeapIdx++; + } + return HeapIdx; } llvm::Error createBuffers(Pipeline &P, InvocationState &IS) { - const size_t ResourceCount = P.getDescriptorCount(); - const size_t TableSize = sizeof(IRDescriptorTableEntry) * ResourceCount; - - if (TableSize > 0) { - IS.ArgBuffer = - Device->newBuffer(TableSize, MTL::ResourceStorageModeManaged); - uint32_t HeapIndex = 0; - for (auto &D : P.Sets) { - for (auto &R : D.Resources) { - if (auto Err = createDescriptor(R, IS, HeapIndex++)) - return Err; + auto CreateBuffer = + [&IS, + this](Resource &R, + llvm::SmallVectorImpl &Resources) -> llvm::Error { + switch (getDescriptorKind(R.Kind)) { + case DescriptorKind::SRV: { + auto ExRes = createSRV(R, IS); + if (!ExRes) + return ExRes.takeError(); + 
Resources.emplace_back(&R, std::move(*ExRes)); + break; + } + case DescriptorKind::UAV: { + auto ExRes = createUAV(R, IS); + if (!ExRes) + return ExRes.takeError(); + Resources.emplace_back(&R, std::move(*ExRes)); + break; + } + case DescriptorKind::CBV: { + auto ExRes = createCBV(R, IS); + if (!ExRes) + return ExRes.takeError(); + Resources.emplace_back(&R, std::move(*ExRes)); + break; + } + case DescriptorKind::SAMPLER: + return llvm::createStringError( + std::errc::not_supported, + "Samplers are not yet implemented for Metal."); + } + return llvm::Error::success(); + }; + + for (auto &D : P.Sets) { + IS.DescTables.emplace_back(DescriptorTable()); + DescriptorTable &Table = IS.DescTables.back(); + for (auto &R : D.Resources) + if (auto Err = CreateBuffer(R, Table.Resources)) + return Err; + } + + // Bind descriptors in descriptor tables. + uint32_t HeapIndex = 0; + for (auto &T : IS.DescTables) { + for (auto &R : T.Resources) { + switch (getDescriptorKind(R.first->Kind)) { + case DescriptorKind::SRV: + HeapIndex = bindSRV(*(R.first), IS, HeapIndex, R.second); + break; + case DescriptorKind::UAV: + HeapIndex = bindUAV(*(R.first), IS, HeapIndex, R.second); + break; + case DescriptorKind::CBV: + HeapIndex = bindCBV(*(R.first), IS, HeapIndex, R.second); + break; + case DescriptorKind::SAMPLER: + llvm_unreachable("Not implemented yet."); } } - IS.ArgBuffer->didModifyRange(NS::Range::Make(0, IS.ArgBuffer->length())); } + if (P.isGraphics()) { // Create and mark the vertex buffer as modified. 
IS.VertexBuffer = Device->newBuffer( @@ -456,52 +871,36 @@ class MTLDevice : public offloadtest::Device { llvm::scope_exit([&]() { CmdEncoder->endEncoding(); }); CmdEncoder->setComputePipelineState(IS.ComputePipeline); - CmdEncoder->setBuffer(IS.ArgBuffer, 0, 2); - for (uint64_t I = 0; I < IS.Textures.size(); ++I) - CmdEncoder->useResource(IS.Textures[I], - MTL::ResourceUsageRead | MTL::ResourceUsageWrite); - for (uint64_t I = 0; I < IS.Buffers.size(); ++I) - CmdEncoder->useResource(IS.Buffers[I], - MTL::ResourceUsageRead | MTL::ResourceUsageWrite); + MTLGPUDescriptorHandle Handle = {}; + if (IS.DescHeap) { + IS.DescHeap->bind(CmdEncoder); + Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); + } + + for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) { + IS.ArgBuffer->setRootDescriptorTable(Idx, Handle); + Handle.Offset(P.Sets[Idx].Resources.size()); + } + + IS.ArgBuffer->bind(CmdEncoder); + for (const auto &Table : IS.DescTables) + for (const auto &ResPair : Table.Resources) + for (const auto &ResSet : ResPair.second) + CmdEncoder->useResource(ResSet.Resource.get(), + MTL::ResourceUsageRead | + MTL::ResourceUsageWrite); NS::UInteger TGS[3] = {IS.ComputePipeline->maxTotalThreadsPerThreadgroup(), 1, 1}; - if (P.Shaders[0].Reflection) { - llvm::Expected E = llvm::json::parse( - llvm::StringRef(P.Shaders[0].Reflection->getBuffer())); - if (!E) - return E.takeError(); - llvm::json::Value Reflection = *E; - - const llvm::json::Object *ReflectionObj = Reflection.getAsObject(); - if (!ReflectionObj) - return llvm::createStringError( - std::errc::invalid_argument, - "Shader reflection must be a JSON object."); - auto StateIt = ReflectionObj->find("state"); - if (StateIt == ReflectionObj->end()) - return llvm::createStringError( - std::errc::invalid_argument, - "Key 'state' not found in shader reflection."); - const llvm::json::Object *State = StateIt->second.getAsObject(); - auto TGSize = State->find("tg_size"); - if (TGSize == State->end()) - return 
llvm::createStringError( - std::errc::invalid_argument, - "Key 'tg_size' not found in shader reflection."); - const llvm::json::Array *TGSizeArr = TGSize->second.getAsArray(); - if (TGSizeArr->size() != 3) - return llvm::createStringError( - std::errc::invalid_argument, - "Threadgroup size in reflection must have three components."); - for (size_t I = 0; I < 3; ++I) { - auto OpVal = (*TGSizeArr)[I].getAsUINT64(); - if (!OpVal) - return llvm::createStringError(std::errc::invalid_argument, - "Threadgroup size components in " - "reflection must be integers."); - TGS[I] = *OpVal; + if (IS.ComputeReflection) { + IRVersionedCSInfo Info; + if (IRShaderReflectionCopyComputeInfo(IS.ComputeReflection.get(), + IRReflectionVersion_1_0, &Info)) { + TGS[0] = Info.info_1_0.tg_size[0]; + TGS[1] = Info.info_1_0.tg_size[1]; + TGS[2] = Info.info_1_0.tg_size[2]; } + IRShaderReflectionReleaseComputeInfo(&Info); } const llvm::ArrayRef DispatchSize = @@ -615,6 +1014,21 @@ class MTLDevice : public offloadtest::Device { DSDesc->release(); DSState->release(); + if (IS.DescHeap) { + IS.DescHeap->bind(CmdEncoder); + // NOTE: This code assumes 1 descriptor set (D3D12 backend also assumes + // this) + IS.ArgBuffer->setRootDescriptorTable( + 0, IS.DescHeap->getGPUDescriptorHandleForHeapStart()); + } + IS.ArgBuffer->bind(CmdEncoder); + for (const auto &Table : IS.DescTables) + for (const auto &ResPair : Table.Resources) + for (const auto &ResSet : ResPair.second) + CmdEncoder->useResource(ResSet.Resource.get(), + MTL::ResourceUsageRead | + MTL::ResourceUsageWrite); + // Explicitly set viewport to texture dimensions. 
CmdEncoder->setViewport( MTL::Viewport{0.0, 0.0, (double)Width, (double)Height, 0.0, 1.0}); @@ -664,33 +1078,37 @@ class MTLDevice : public offloadtest::Device { } llvm::Error copyBack(Pipeline &P, InvocationState &IS) { - uint32_t TextureIndex = 0; - uint32_t BufferIndex = 0; - for (auto &D : P.Sets) { - for (auto &R : D.Resources) { - assert(R.BufferPtr->ArraySize == 1 && - "Resource arrays are not yet supported on Metal."); - if (R.isReadOnly()) { - if (R.isRaw()) - ++BufferIndex; - else - ++TextureIndex; - continue; - } + auto MemCpyBack = [](ResourcePair &Pair) -> llvm::Error { + const Resource &R = *Pair.first; + if (!R.isReadWrite()) + return llvm::Error::success(); + + const CPUBuffer &B = *R.BufferPtr; + auto *RSIt = Pair.second.begin(); + auto *DataIt = B.Data.begin(); + for (; RSIt != Pair.second.end() && DataIt != B.Data.end(); + ++RSIt, ++DataIt) { if (R.isRaw()) { - memcpy(R.BufferPtr->Data.back().get(), - IS.Buffers[BufferIndex++]->contents(), R.size()); - continue; + MTL::Buffer *Buf = static_cast(RSIt->Resource.get()); + memcpy(DataIt->get(), Buf->contents(), Buf->length()); + } else { + MTL::Texture *Tex = static_cast(RSIt->Resource.get()); + const uint64_t Width = R.isTexture() ? B.OutputProps.Width + : R.size() / R.getElementSize(); + const uint64_t Height = R.isTexture() ? B.OutputProps.Height : 1; + Tex->getBytes(DataIt->get(), Width * R.getElementSize(), + MTL::Region(0, 0, Width, Height), 0); } - const uint64_t Width = R.isTexture() ? R.BufferPtr->OutputProps.Width - : R.size() / R.getElementSize(); - const uint64_t Height = - R.isTexture() ? 
R.BufferPtr->OutputProps.Height : 1; - IS.Textures[TextureIndex++]->getBytes( - R.BufferPtr->Data.back().get(), Width * R.getElementSize(), - MTL::Region(0, 0, Width, Height), 0); } - } + + return llvm::Error::success(); + }; + + for (auto &Table : IS.DescTables) + for (auto &R : Table.Resources) + if (auto Err = MemCpyBack(R)) + return Err; + if (P.isGraphics()) { CPUBuffer *RTarget = P.Bindings.RTargetBufferPtr; const uint64_t Width = RTarget->OutputProps.Width; @@ -781,6 +1199,12 @@ class MTLDevice : public offloadtest::Device { return FenceOrErr.takeError(); IS.Fence = std::move(*FenceOrErr); + if (auto Err = createRootSignature(P, IS)) + return Err; + + if (auto Err = createDescriptorHeap(P, IS)) + return Err; + if (auto Err = createBuffers(P, IS)) return Err; diff --git a/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp b/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp new file mode 100644 index 000000000..689bb8123 --- /dev/null +++ b/lib/API/MTL/MTLTopLevelArgumentBuffer.cpp @@ -0,0 +1,182 @@ +#include "MTLTopLevelArgumentBuffer.h" + +using namespace offloadtest; + +static llvm::StringRef getResourceTypeString(IRResourceType Type) { + switch (Type) { + case IRResourceTypeCBV: + return "CBV"; + case IRResourceTypeSRV: + return "SRV"; + case IRResourceTypeUAV: + return "UAV"; + case IRResourceTypeTable: + return "Table"; + default: + return "Unknown"; + } +} + +llvm::Expected> +MTLTopLevelArgumentBuffer::create(MTL::Device *Device, + IRRootSignature *RootSig) { + if (!Device) + return llvm::createStringError(std::errc::invalid_argument, + "Invalid MTL::Device pointer."); + + if (!RootSig) + return llvm::createStringError(std::errc::invalid_argument, + "Invalid IRRootSignature pointer."); + + std::vector ResourceLocs( + IRRootSignatureGetResourceCount(RootSig)); + // Empty root signature is valid, bind methods will be no-ops + if (ResourceLocs.empty()) { + return std::make_unique(ResourceLocs, nullptr); + } + + IRRootSignatureGetResourceLocations(RootSig, 
ResourceLocs.data()); + + size_t BufferSize = 0; + for (size_t ResourceIdx = 0; ResourceIdx < ResourceLocs.size(); + ++ResourceIdx) { + const IRResourceLocation &Loc = ResourceLocs[ResourceIdx]; + BufferSize += Loc.sizeBytes; + + llvm::outs() << "Resource[" << ResourceIdx << "] {" + << " Type=" << getResourceTypeString(Loc.resourceType) << "," + << " Space=" << Loc.space << "," + << " Slot=" << Loc.slot << "," + << " TopLevelOffset=" << Loc.topLevelOffset << "," + << " SizeInBytes=" << Loc.sizeBytes << " }\n"; + } + MTL::Buffer *Buffer = + Device->newBuffer(BufferSize, MTL::ResourceStorageModeShared); + if (!Buffer) + return llvm::createStringError( + std::errc::not_enough_memory, + "Failed to create top-level argument buffer."); + return std::make_unique(std::move(ResourceLocs), + Buffer); +} + +MTLTopLevelArgumentBuffer::~MTLTopLevelArgumentBuffer() { + if (Buffer) + Buffer->release(); +} + +bool MTLTopLevelArgumentBuffer::checkIndex(uint32_t Index) const { + if (Index >= ResourceLocs.size()) { + llvm::errs() << "Invalid index " << Index << ", only " + << ResourceLocs.size() + << " resources in root " + "signature.\n"; + return false; + } + return true; +} + +bool MTLTopLevelArgumentBuffer::checkResourceType( + uint32_t Index, IRResourceType ExpectedType) const { + const IRResourceLocation &Loc = ResourceLocs[Index]; + if (Loc.resourceType != ExpectedType) { + llvm::errs() << "Resource type mismatch for index " << Index + << ", expected " << static_cast(ExpectedType) + << " but root signature specifies " + << static_cast(Loc.resourceType) << ".\n"; + return false; + } + return true; +} + +bool MTLTopLevelArgumentBuffer::checkResourceSize(uint32_t Index, + size_t ExpectedSize) const { + const IRResourceLocation &Loc = ResourceLocs[Index]; + if (Loc.sizeBytes != ExpectedSize) { + llvm::errs() << "Size mismatch for index " << Index << ", expected " + << ExpectedSize << " but root signature specifies " + << Loc.sizeBytes << ".\n"; + return false; + } + return true; 
+} + +template +void MTLTopLevelArgumentBuffer::setResource(uint32_t Index, T Resource) const { + if (!Buffer || !checkIndex(Index) || + !checkResourceType(Index, ResourceType) || + !checkResourceSize(Index, sizeof(T))) + return; + + const IRResourceLocation &Loc = ResourceLocs[Index]; + std::byte *Dst = + static_cast(Buffer->contents()) + Loc.topLevelOffset; + memcpy(Dst, &Resource, sizeof(T)); +} + +void MTLTopLevelArgumentBuffer::setRoot32BitConstant( + uint32_t Index, uint32_t SrcData, uint32_t DestOffsetIn32BitValues) const { + setRoot32BitConstants(Index, 1, &SrcData, DestOffsetIn32BitValues); +} + +void MTLTopLevelArgumentBuffer::setRoot32BitConstants( + uint32_t Index, uint32_t Num32BitValuesToSet, const void *pSrcData, + uint32_t DestOffsetIn32BitValues) const { + if (!Buffer || !checkIndex(Index) || + !checkResourceType(Index, IRResourceTypeConstant)) + return; + + const IRResourceLocation &Loc = ResourceLocs[Index]; + if ((DestOffsetIn32BitValues + Num32BitValuesToSet) * sizeof(uint32_t) > + Loc.sizeBytes) { + llvm::errs() << "Size mismatch for index " << Index << ", root signature " + << "specifies " << Loc.sizeBytes << " bytes but trying to set " + << (DestOffsetIn32BitValues + Num32BitValuesToSet) * + sizeof(uint32_t) + << " bytes.\n"; + return; + } + + std::byte *Dst = static_cast(Buffer->contents()) + + Loc.topLevelOffset + + DestOffsetIn32BitValues * sizeof(uint32_t); + memcpy(Dst, pSrcData, Num32BitValuesToSet * sizeof(uint32_t)); +} + +void MTLTopLevelArgumentBuffer::setRootConstantBufferView( + uint32_t Index, uint64_t GPUAddr) const { + setResource(Index, GPUAddr); +} + +void MTLTopLevelArgumentBuffer::setRootShaderResourceView( + uint32_t Index, uint64_t GPUAddr) const { + setResource(Index, GPUAddr); +} + +void MTLTopLevelArgumentBuffer::setRootUnorderedAccessView( + uint32_t Index, uint64_t GPUAddr) const { + setResource(Index, GPUAddr); +} + +void MTLTopLevelArgumentBuffer::setRootDescriptorTable( + uint32_t Index, MTLGPUDescriptorHandle 
BaseHandle) const { + setResource(Index, BaseHandle); +} + +void MTLTopLevelArgumentBuffer::bind(MTL::RenderCommandEncoder *Encoder) const { + if (!Buffer) + return; + + Encoder->useResource(Buffer, MTL::ResourceUsageRead); + Encoder->setVertexBuffer(Buffer, 0, kIRArgumentBufferBindPoint); + Encoder->setFragmentBuffer(Buffer, 0, kIRArgumentBufferBindPoint); +} + +void MTLTopLevelArgumentBuffer::bind( + MTL::ComputeCommandEncoder *Encoder) const { + if (!Buffer) + return; + + Encoder->useResource(Buffer, MTL::ResourceUsageRead); + Encoder->setBuffer(Buffer, 0, kIRArgumentBufferBindPoint); +} diff --git a/lib/API/MTL/MTLTopLevelArgumentBuffer.h b/lib/API/MTL/MTLTopLevelArgumentBuffer.h new file mode 100644 index 000000000..b370d7b10 --- /dev/null +++ b/lib/API/MTL/MTLTopLevelArgumentBuffer.h @@ -0,0 +1,72 @@ +//===- MTLTopLevelArgumentBuffer.h - Metal Top Level Argument Buffer ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#ifndef OFFLOADTEST_API_MTL_MTLTOPLEVELARGUMENTBUFFER_H +#define OFFLOADTEST_API_MTL_MTLTOPLEVELARGUMENTBUFFER_H + +#include "MTLDescriptorHeap.h" +#include "MetalIRConverter.h" + +namespace offloadtest { +// Manages a Metal buffer that serves as the top-level argument buffer for +// shader resource binding with the explicit root signature layout. 
+class MTLTopLevelArgumentBuffer { + std::vector ResourceLocs; + MTL::Buffer *Buffer = nullptr; + + bool checkIndex(uint32_t Index) const; + bool checkResourceType(uint32_t Index, IRResourceType ExpectedType) const; + bool checkResourceSize(uint32_t Index, size_t ExpectedSize) const; + + template + void setResource(uint32_t Index, T Resource) const; + +public: + /// @brief Creates a MTLTopLevelArgumentBuffer based on the given root + /// signature. Empty root signature (zero resources) is allowed, bind methods + /// will be no-op in that case. + /// @param Device Metal device to create the argument buffer on. + /// @param RootSig Root signature describing the layout of the argument + /// buffer. + /// @return Created MTLTopLevelArgumentBuffer or error if creation failed. + static llvm::Expected> + create(MTL::Device *Device, IRRootSignature *RootSig); + + MTLTopLevelArgumentBuffer(std::vector ResourceLocs, + MTL::Buffer *Buffer) + : ResourceLocs(std::move(ResourceLocs)), Buffer(Buffer) {} + ~MTLTopLevelArgumentBuffer(); + + // Binds 32-bit root constant(s) to the argument buffer. + void setRoot32BitConstant(uint32_t Index, uint32_t SrcData, + uint32_t DestOffsetIn32BitValues) const; + void setRoot32BitConstants(uint32_t Index, uint32_t Num32BitValuesToSet, + const void *pSrcData, + uint32_t DestOffsetIn32BitValues) const; + + // Binds CBV/SRV/UAV resource via GPU address or resource ID to the argument + // buffer. + void setRootConstantBufferView(uint32_t Index, uint64_t GPUAddr) const; + void setRootShaderResourceView(uint32_t Index, uint64_t GPUAddr) const; + void setRootUnorderedAccessView(uint32_t Index, uint64_t GPUAddr) const; + + // Binds descriptor table to the argument buffer. + void setRootDescriptorTable(uint32_t Index, + MTLGPUDescriptorHandle BaseHandle) const; + + // Bind the argument buffer to the render command encoder. + void bind(MTL::RenderCommandEncoder *Encoder) const; + // Bind the argument buffer to the compute command encoder. 
+ void bind(MTL::ComputeCommandEncoder *Encoder) const; +}; +} // namespace offloadtest + +#endif // OFFLOADTEST_API_MTL_MTLTOPLEVELARGUMENTBUFFER_H diff --git a/lib/API/MTL/MetalIRConverter.h b/lib/API/MTL/MetalIRConverter.h new file mode 100644 index 000000000..41eb69629 --- /dev/null +++ b/lib/API/MTL/MetalIRConverter.h @@ -0,0 +1,22 @@ +//===- MetalIRConverter.h - Metal IR Converter ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#ifndef OFFLOADTEST_API_MTL_METALIRCONVERTER_H +#define OFFLOADTEST_API_MTL_METALIRCONVERTER_H + +#pragma push_macro("IR_RUNTIME_METALCPP") +#define IR_RUNTIME_METALCPP +#include "Metal/Metal.hpp" +#include "metal_irconverter.h" +#include "metal_irconverter_runtime.h" +#pragma pop_macro("IR_RUNTIME_METALCPP") + +#endif // OFFLOADTEST_API_MTL_METALIRCONVERTER_H diff --git a/test/Feature/ResourceArrays/array-global.test b/test/Feature/ResourceArrays/array-global.test index bbc07f061..b37370cf2 100644 --- a/test/Feature/ResourceArrays/array-global.test +++ b/test/Feature/ResourceArrays/array-global.test @@ -134,10 +134,6 @@ DescriptorSets: ... 
#--- end -# Offload tests are missing support for resource arrays on Metal -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o | FileCheck %s diff --git a/test/Feature/ResourceArrays/array-of-constant-buffers.test b/test/Feature/ResourceArrays/array-of-constant-buffers.test index 3418de22c..ebb73bcf3 100644 --- a/test/Feature/ResourceArrays/array-of-constant-buffers.test +++ b/test/Feature/ResourceArrays/array-of-constant-buffers.test @@ -70,10 +70,6 @@ DescriptorSets: # Unimplemented https://github.com/llvm/llvm-project/issues/133835 # XFAIL: Clang -# Offload tests are missing support for resource arrays on Metal -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/ResourceArrays/multi-dim-array-subset.test b/test/Feature/ResourceArrays/multi-dim-array-subset.test index ead1a5089..6b39f6f26 100644 --- a/test/Feature/ResourceArrays/multi-dim-array-subset.test +++ b/test/Feature/ResourceArrays/multi-dim-array-subset.test @@ -72,9 +72,6 @@ DescriptorSets: # Unimplemented https://github.com/llvm/llvm-project/issues/164908 # XFAIL: Clang && Vulkan -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/ResourceArrays/multi-dim-unbounded-array.test b/test/Feature/ResourceArrays/multi-dim-unbounded-array.test index edae62a38..2bcfac4fd 100644 --- a/test/Feature/ResourceArrays/multi-dim-unbounded-array.test +++ b/test/Feature/ResourceArrays/multi-dim-unbounded-array.test @@ -54,9 +54,6 @@ DescriptorSets: ... 
#--- end -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # Vulkan does not support multi-dimensional resource arrays # UNSUPPORTED: Vulkan diff --git a/test/Feature/ResourceArrays/overflow-unbounded-array.test b/test/Feature/ResourceArrays/overflow-unbounded-array.test index 91319c8c7..4d5be1bce 100644 --- a/test/Feature/ResourceArrays/overflow-unbounded-array.test +++ b/test/Feature/ResourceArrays/overflow-unbounded-array.test @@ -65,9 +65,6 @@ DescriptorSets: ... #--- end -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/ResourceArrays/unbounded-array.test b/test/Feature/ResourceArrays/unbounded-array.test index 1d319e11b..1b27ffda0 100644 --- a/test/Feature/ResourceArrays/unbounded-array.test +++ b/test/Feature/ResourceArrays/unbounded-array.test @@ -55,7 +55,6 @@ DescriptorSets: #--- end # Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal # XFAIL: DXC && Vulkan && KosmicKrisp # RUN: split-file %s %t diff --git a/test/Feature/ResourcesInStructs/res-array-of-matrix-in-struct.test b/test/Feature/ResourcesInStructs/res-array-of-matrix-in-struct.test index 99c0c9ddc..ba5a3bd00 100644 --- a/test/Feature/ResourcesInStructs/res-array-of-matrix-in-struct.test +++ b/test/Feature/ResourcesInStructs/res-array-of-matrix-in-struct.test @@ -78,9 +78,6 @@ DescriptorSets: # Unimplemented https://github.com/microsoft/DirectXShaderCompiler/issues/8301 # XFAIL: Vulkan && DXC -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/ResourcesInStructs/res-in-struct-mix.test b/test/Feature/ResourcesInStructs/res-in-struct-mix.test index ddee0f523..3fe175aa4 
100644 --- a/test/Feature/ResourcesInStructs/res-in-struct-mix.test +++ b/test/Feature/ResourcesInStructs/res-in-struct-mix.test @@ -122,9 +122,6 @@ DescriptorSets: # Unimplemented https://github.com/llvm/wg-hlsl/issues/367 # XFAIL: Clang -# Bug https://github.com/llvm/offload-test-suite/issues/642 -# XFAIL: Metal - # Vulkan does not support global structures with buffers or structures # containing both resources and non-resources. # UNSUPPORTED: Vulkan diff --git a/test/Feature/ResourcesInStructs/res-in-struct-simple-array.test b/test/Feature/ResourcesInStructs/res-in-struct-simple-array.test index d6e571e7c..1b50977aa 100644 --- a/test/Feature/ResourcesInStructs/res-in-struct-simple-array.test +++ b/test/Feature/ResourcesInStructs/res-in-struct-simple-array.test @@ -73,9 +73,6 @@ DescriptorSets: # Unimplemented https://github.com/llvm/wg-hlsl/issues/367 # XFAIL: Clang -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/Semantics/MatrixSemantics.test b/test/Feature/Semantics/MatrixSemantics.test index 743e7fcea..bffaba7df 100644 --- a/test/Feature/Semantics/MatrixSemantics.test +++ b/test/Feature/Semantics/MatrixSemantics.test @@ -99,9 +99,6 @@ Results: # Semantics are not implemented in the DXIL backend. # XFAIL: Clang && !Vulkan -# See https://github.com/llvm/offload-test-suite/issues/744 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T vs_6_0 -Fo %t-vertex.o %t/vertex.hlsl # RUN: %dxc_target -T ps_6_0 -Fo %t-pixel.o %t/pixel.hlsl diff --git a/test/Feature/Semantics/NestedStructSemantics.test b/test/Feature/Semantics/NestedStructSemantics.test index e2f7aadda..40f687004 100644 --- a/test/Feature/Semantics/NestedStructSemantics.test +++ b/test/Feature/Semantics/NestedStructSemantics.test @@ -102,9 +102,6 @@ Results: # Semantics are not implemented in the DXIL backend. 
# XFAIL: Clang && !Vulkan -# See https://github.com/llvm/offload-test-suite/issues/744 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T vs_6_0 -Fo %t-vertex.o %t/vertex.hlsl # RUN: %dxc_target -T ps_6_0 -Fo %t-pixel.o %t/pixel.hlsl diff --git a/test/Feature/Semantics/SemanticTypes.test b/test/Feature/Semantics/SemanticTypes.test index c4842702a..393f04204 100644 --- a/test/Feature/Semantics/SemanticTypes.test +++ b/test/Feature/Semantics/SemanticTypes.test @@ -104,9 +104,6 @@ Results: # Semantics are not implemented in the DXIL backend. # XFAIL: Clang && !Vulkan -# See https://github.com/llvm/offload-test-suite/issues/744 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T vs_6_0 -Fo %t-vertex.o %t/vertex.hlsl # RUN: %dxc_target -T ps_6_0 -Fo %t-pixel.o %t/pixel.hlsl diff --git a/test/Feature/Semantics/ShadowedSemantics.test b/test/Feature/Semantics/ShadowedSemantics.test index cf5a7153e..03f36e422 100644 --- a/test/Feature/Semantics/ShadowedSemantics.test +++ b/test/Feature/Semantics/ShadowedSemantics.test @@ -107,9 +107,6 @@ Results: # Semantics are not implemented in the DXIL backend. # XFAIL: Clang && !Vulkan -# See https://github.com/llvm/offload-test-suite/issues/744 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T vs_6_0 -Fo %t-vertex.o %t/vertex.hlsl # RUN: %dxc_target -T ps_6_0 -Fo %t-pixel.o %t/pixel.hlsl diff --git a/test/Feature/StructuredBuffer/matrix.test b/test/Feature/StructuredBuffer/matrix.test index a8e9750e8..b04fe0795 100644 --- a/test/Feature/StructuredBuffer/matrix.test +++ b/test/Feature/StructuredBuffer/matrix.test @@ -108,9 +108,6 @@ DescriptorSets: ... 
#--- end -# Unimplemented https://github.com/llvm/offload-test-suite/issues/1021 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -fvk-use-dx-layout -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/StructuredBuffer/matrix_assign.test b/test/Feature/StructuredBuffer/matrix_assign.test index af767a456..ab54d9487 100644 --- a/test/Feature/StructuredBuffer/matrix_assign.test +++ b/test/Feature/StructuredBuffer/matrix_assign.test @@ -97,9 +97,6 @@ DescriptorSets: ... #--- end -# Unimplemented https://github.com/llvm/offload-test-suite/issues/1021 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -fvk-use-dx-layout -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/Feature/Textures/Texture2D.OperatorIndex.test.yaml b/test/Feature/Textures/Texture2D.OperatorIndex.test.yaml index 7181ff0a4..fd187946f 100644 --- a/test/Feature/Textures/Texture2D.OperatorIndex.test.yaml +++ b/test/Feature/Textures/Texture2D.OperatorIndex.test.yaml @@ -62,7 +62,8 @@ Results: #--- end # Unimplemented: Clang + DX: https://github.com/llvm/llvm-project/issues/101558 -# XFAIL: DirectX || Metal +# XFAIL: DirectX +# XFAIL: Clang && Metal # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl diff --git a/test/Feature/Textures/Texture2D.SRVToUAV.array.test.yaml b/test/Feature/Textures/Texture2D.SRVToUAV.array.test.yaml index 262c0103c..e1e444cf0 100644 --- a/test/Feature/Textures/Texture2D.SRVToUAV.array.test.yaml +++ b/test/Feature/Textures/Texture2D.SRVToUAV.array.test.yaml @@ -91,10 +91,6 @@ DescriptorSets: # Unimplemented https://github.com/llvm/llvm-project/issues/133835 # XFAIL: Clang -# Offload tests are missing support for resource arrays on Metal -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff 
--git a/test/Tools/Offloader/BufferExact-error-array.test b/test/Tools/Offloader/BufferExact-error-array.test index f88e1bbaf..b05ec0df7 100644 --- a/test/Tools/Offloader/BufferExact-error-array.test +++ b/test/Tools/Offloader/BufferExact-error-array.test @@ -46,10 +46,6 @@ DescriptorSets: ... #--- end -# Offload tests are missing support for resource arrays on Metal -# Unimplemented https://github.com/llvm/offload-test-suite/issues/305 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: not %offloader %t/pipeline.yaml %t.o 2>&1 | FileCheck %s diff --git a/test/WaveOps/QuadReadAcrossDiagonal.32.test b/test/WaveOps/QuadReadAcrossDiagonal.32.test index 3261d4b64..a6e8800f9 100644 --- a/test/WaveOps/QuadReadAcrossDiagonal.32.test +++ b/test/WaveOps/QuadReadAcrossDiagonal.32.test @@ -344,9 +344,6 @@ DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/QuadReadAcrossDiagonal.int16.test b/test/WaveOps/QuadReadAcrossDiagonal.int16.test index 18b34e08e..634f7fe5f 100644 --- a/test/WaveOps/QuadReadAcrossDiagonal.int16.test +++ b/test/WaveOps/QuadReadAcrossDiagonal.int16.test @@ -240,9 +240,6 @@ DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/QuadReadAcrossX.32.test b/test/WaveOps/QuadReadAcrossX.32.test index 0e0ace848..39fa27a28 100644 --- a/test/WaveOps/QuadReadAcrossX.32.test +++ b/test/WaveOps/QuadReadAcrossX.32.test @@ -340,9 +340,6 @@ 
DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/QuadReadAcrossX.int16.test b/test/WaveOps/QuadReadAcrossX.int16.test index 5b09c573e..31bb29dd7 100644 --- a/test/WaveOps/QuadReadAcrossX.int16.test +++ b/test/WaveOps/QuadReadAcrossX.int16.test @@ -236,9 +236,6 @@ DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/QuadReadAcrossY.32.test b/test/WaveOps/QuadReadAcrossY.32.test index b5b6ca24e..9e8e554a5 100644 --- a/test/WaveOps/QuadReadAcrossY.32.test +++ b/test/WaveOps/QuadReadAcrossY.32.test @@ -340,9 +340,6 @@ DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/QuadReadAcrossY.int16.test b/test/WaveOps/QuadReadAcrossY.int16.test index ba38fe5a2..af7a46e51 100644 --- a/test/WaveOps/QuadReadAcrossY.int16.test +++ b/test/WaveOps/QuadReadAcrossY.int16.test @@ -236,9 +236,6 @@ DescriptorSets: # Bug: https://github.com/llvm/offload-test-suite/issues/986 # XFAIL: Intel && Vulkan && DXC -# Bug: https://github.com/llvm/offload-test-suite/issues/989 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git 
a/test/WaveOps/WaveActiveMax.fp16.test b/test/WaveOps/WaveActiveMax.fp16.test index e8e10a244..8f727322c 100644 --- a/test/WaveOps/WaveActiveMax.fp16.test +++ b/test/WaveOps/WaveActiveMax.fp16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # REQUIRES: Half # RUN: split-file %s %t diff --git a/test/WaveOps/WaveActiveMax.fp32.test b/test/WaveOps/WaveActiveMax.fp32.test index 9cc0eb222..3ab5371cb 100644 --- a/test/WaveOps/WaveActiveMax.fp32.test +++ b/test/WaveOps/WaveActiveMax.fp32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/WaveActiveMax.int16.test b/test/WaveOps/WaveActiveMax.int16.test index a68f23511..665765015 100644 --- a/test/WaveOps/WaveActiveMax.int16.test +++ b/test/WaveOps/WaveActiveMax.int16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # REQUIRES: Int16 # RUN: split-file %s %t diff --git a/test/WaveOps/WaveActiveMax.int32.test b/test/WaveOps/WaveActiveMax.int32.test index 0ce7e5573..6cb612190 100644 --- a/test/WaveOps/WaveActiveMax.int32.test +++ b/test/WaveOps/WaveActiveMax.int32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/WaveActiveMin.fp16.test 
b/test/WaveOps/WaveActiveMin.fp16.test index 591476cf8..e0c63f7bf 100644 --- a/test/WaveOps/WaveActiveMin.fp16.test +++ b/test/WaveOps/WaveActiveMin.fp16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # REQUIRES: Half # RUN: split-file %s %t diff --git a/test/WaveOps/WaveActiveMin.fp32.test b/test/WaveOps/WaveActiveMin.fp32.test index 34a7525a8..24d6ecdfd 100644 --- a/test/WaveOps/WaveActiveMin.fp32.test +++ b/test/WaveOps/WaveActiveMin.fp32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/WaveActiveMin.int16.test b/test/WaveOps/WaveActiveMin.int16.test index 6bd60b54f..80eaf0379 100644 --- a/test/WaveOps/WaveActiveMin.int16.test +++ b/test/WaveOps/WaveActiveMin.int16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # REQUIRES: Int16 # RUN: split-file %s %t diff --git a/test/WaveOps/WaveActiveMin.int32.test b/test/WaveOps/WaveActiveMin.int32.test index c3d314742..7ef46865e 100644 --- a/test/WaveOps/WaveActiveMin.int32.test +++ b/test/WaveOps/WaveActiveMin.int32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/172577 # XFAIL: Clang && Vulkan -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/WaveActiveSum.int16.test b/test/WaveOps/WaveActiveSum.int16.test index 
e754d9c2e..4659d7c4a 100644 --- a/test/WaveOps/WaveActiveSum.int16.test +++ b/test/WaveOps/WaveActiveSum.int16.test @@ -329,7 +329,7 @@ DescriptorSets: # XFAIL: NV && Clang && DirectX # Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal || (Vulkan && MoltenVK) +# XFAIL: (Vulkan && MoltenVK) # RUN: split-file %s %t # RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl diff --git a/test/WaveOps/WaveActiveSum.int32.test b/test/WaveOps/WaveActiveSum.int32.test index 9a99984fd..035d9fad1 100644 --- a/test/WaveOps/WaveActiveSum.int32.test +++ b/test/WaveOps/WaveActiveSum.int32.test @@ -319,10 +319,6 @@ DescriptorSets: ... #--- end - -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # Bug https://github.com/llvm/llvm-project/issues/156775 # XFAIL: Vulkan && Clang diff --git a/test/WaveOps/WavePrefixSum.int16.test b/test/WaveOps/WavePrefixSum.int16.test index 9b8151d8b..0b3a258ed 100644 --- a/test/WaveOps/WavePrefixSum.int16.test +++ b/test/WaveOps/WavePrefixSum.int16.test @@ -321,9 +321,6 @@ DescriptorSets: # Bug: https://github.com/llvm/llvm-project/issues/186151 # XFAIL: Clang && Vulkan -# Bug: https://github.com/llvm/offload-test-suite/issues/960 -# XFAIL: Metal - # Bug: https://github.com/llvm/offload-test-suite/issues/961 # XFAIL: NV && Clang && DirectX diff --git a/test/WaveOps/WaveReadLaneAt.16.test b/test/WaveOps/WaveReadLaneAt.16.test index 64709ef94..e3995321e 100644 --- a/test/WaveOps/WaveReadLaneAt.16.test +++ b/test/WaveOps/WaveReadLaneAt.16.test @@ -136,9 +136,6 @@ DescriptorSets: ... 
#--- end -# Bug https://github.com/llvm/offload-test-suite/issues/351 -# XFAIL: Metal - # Bug https://github.com/llvm/offload-test-suite/issues/532 # XFAIL: DirectX && QC diff --git a/test/WaveOps/WaveReadLaneAt.32.test b/test/WaveOps/WaveReadLaneAt.32.test index aa758caf7..b0268bc31 100644 --- a/test/WaveOps/WaveReadLaneAt.32.test +++ b/test/WaveOps/WaveReadLaneAt.32.test @@ -172,8 +172,6 @@ DescriptorSets: Binding: 7 ... #--- end -# Bug https://github.com/llvm/offload-test-suite/issues/351 -# XFAIL: Metal # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Gis -Fo %t.o %t/source.hlsl diff --git a/test/WaveOps/WaveReadLaneFirst.fp16.test b/test/WaveOps/WaveReadLaneFirst.fp16.test index 9173d417a..67eb034ae 100644 --- a/test/WaveOps/WaveReadLaneFirst.fp16.test +++ b/test/WaveOps/WaveReadLaneFirst.fp16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/156775 # XFAIL: Clang -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # Bug https://github.com/llvm/offload-test-suite/issues/627 # XFAIL: QC && DirectX diff --git a/test/WaveOps/WaveReadLaneFirst.fp32.test b/test/WaveOps/WaveReadLaneFirst.fp32.test index e6f32eae0..d2f88a616 100644 --- a/test/WaveOps/WaveReadLaneFirst.fp32.test +++ b/test/WaveOps/WaveReadLaneFirst.fp32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/156775 # XFAIL: Clang -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # Bug https://github.com/llvm/offload-test-suite/issues/627 # XFAIL: QC && DirectX diff --git a/test/WaveOps/WaveReadLaneFirst.int16.test b/test/WaveOps/WaveReadLaneFirst.int16.test index 29f1b0ae1..042688b5c 100644 --- a/test/WaveOps/WaveReadLaneFirst.int16.test +++ b/test/WaveOps/WaveReadLaneFirst.int16.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/156775 # XFAIL: Clang -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal 
- # RUN: split-file %s %t # RUN: %dxc_target -enable-16bit-types -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/WaveOps/WaveReadLaneFirst.int32.test b/test/WaveOps/WaveReadLaneFirst.int32.test index 29680fc85..87809ce40 100644 --- a/test/WaveOps/WaveReadLaneFirst.int32.test +++ b/test/WaveOps/WaveReadLaneFirst.int32.test @@ -307,9 +307,6 @@ DescriptorSets: # Bug https://github.com/llvm/llvm-project/issues/156775 # XFAIL: Clang -# Bug https://github.com/llvm/offload-test-suite/issues/393 -# XFAIL: Metal - # RUN: split-file %s %t # RUN: %dxc_target -T cs_6_5 -Fo %t.o %t/source.hlsl # RUN: %offloader %t/pipeline.yaml %t.o diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 029953e74..cea984878 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -179,8 +179,6 @@ def setDeviceFeatures(config, device, compiler): offloader_args.append("-validation-layer") if ShouldSearchByGPuName: offloader_args.extend([f'-adapter-regex="{GPUName}"']) -if config.offloadtest_enable_metal: - offloader_args.append("-reflection=%t.json") tools.append( ToolSubst("%offloader", command=FindTool("offloader"), extra_args=offloader_args) ) @@ -191,7 +189,6 @@ def setDeviceFeatures(config, device, compiler): if config.offloadtest_test_clang: ExtraCompilerArgs.append("-fspv-extension=DXC") if config.offloadtest_enable_metal: - ExtraCompilerArgs = ["-metal", "-Fre", "%t.json"] # metal-irconverter version: 3.0.0 MSCVersionOutput = subprocess.check_output( ["metal-shaderconverter", "--version"] diff --git a/third-party/ThirdParty.md b/third-party/ThirdParty.md index 7bfb34cf7..9bbd11e7c 100644 --- a/third-party/ThirdParty.md +++ b/third-party/ThirdParty.md @@ -18,10 +18,19 @@ binaries through the DirectX API. The libpng library is used by the test runtime to emit viewable images in PNG format. 
+## metal_ir_converter + +* URL: https://developer.apple.com/metal/shader-converter/ +* Version: 3.1.0 +* License: Apache 2.0 + +Metal shader converter converts shader intermediate representations in LLVM IR bytecode +into a form suitable to be loaded into Metal. + ## metal_ir_converter_runtime * URL: https://developer.apple.com/metal/shader-converter/ -* Version: 2.0.0 +* Version: 3.1.0 * License: Apache 2.0 The Metal IR Converter runtime is part of the Metal Shader Converter tooling diff --git a/third-party/metal_irconverter_runtime/LICENSE.txt b/third-party/metal_irconverter_runtime/LICENSE.txt index 6c877ff98..408ab67a2 100644 --- a/third-party/metal_irconverter_runtime/LICENSE.txt +++ b/third-party/metal_irconverter_runtime/LICENSE.txt @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright © 2023 Apple Inc. + Copyright © 2023-2025 Apple Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/third-party/metal_irconverter_runtime/ir_raytracing.h b/third-party/metal_irconverter_runtime/ir_raytracing.h index 6201e72ac..ac9a48ca7 100644 --- a/third-party/metal_irconverter_runtime/ir_raytracing.h +++ b/third-party/metal_irconverter_runtime/ir_raytracing.h @@ -1,6 +1,6 @@ //------------------------------------------------------------------------------------------------------------------------------------------------------------- // -// Copyright 2023 Apple Inc. +// Copyright 2023-2025 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -34,18 +34,37 @@ #ifdef __METAL_VERSION__ #define IR_CONSTANT_PTR(ptr) constant ptr* + #define IR_DEVICE_PTR(ptr) device ptr* #else #define IR_CONSTANT_PTR(ptr) uint64_t + #define IR_DEVICE_PTR(ptr) uint64_t #endif // __METAL_VERSION__ +#ifndef IR_DEPRECATED +#ifdef _MSC_VER +#define IR_DEPRECATED(message) __declspec(deprecated(message)) +#else +#define IR_DEPRECATED(message) __attribute__((deprecated(message))) +#endif // _MSC_VER +#endif + typedef struct IRShaderIdentifier { - // For HitGroups, index into visible function table containing a converted - // intersection function. + // For HitGroups: + // If compilation mode is IRIntersectionFunctionCompilationIntersectionFunctionBufferFunction, + // intersection function handle resource ID of converted intersection or any-hit shader. + // If compilation mode is IRIntersectionFunctionCompilationVisibleFunction, + // index into visible function table containing converted intersection or any-hit shader. + // If compilation mode is IRIntersectionFunctionCompilationIntersectionFunction, + // index into intersection function table containing converted intersection or any-hit shader. uint64_t intersectionShaderHandle; - // For ray generation, miss, callable shaders, index into visible function - // table containing the translated function. For HitGroups, index to the - // converted closest-hit shader. + // For ray-generation shaders: + // If compilation mode is IRRayGenerationCompilationVisibleFunction, + // index into visible function table containing the converted shader. + // For miss, callable shaders: + // index into visible function table containing the converted shader. + // For HitGroups: + // index to the converted closest-hit shader. 
uint64_t shaderHandle; // GPU address to a buffer containing static samplers for shader records uint64_t localRootSignatureSamplersBuffer; @@ -96,9 +115,11 @@ typedef struct IRDispatchRaysDescriptor using RaygenFunctionType = void(constant top_level_global_ab*, constant top_level_local_ab*, constant res_desc_heap_ab*, constant smp_desc_heap_ab*, constant IRDispatchRaysArgument*, uint3); #define RaygenFunctionPointerTable metal::visible_function_table #define IFT metal::raytracing::intersection_function_table<> + #define MSLAccelerationStructure metal::raytracing::instance_acceleration_structure #else #define RaygenFunctionPointerTable resourceid_t #define IFT resourceid_t + #define MSLAccelerationStructure uint64_t #endif typedef struct IRDispatchRaysArgument @@ -109,7 +130,7 @@ typedef struct IRDispatchRaysArgument IR_CONSTANT_PTR(smp_desc_heap_ab) SmpDescHeap; RaygenFunctionPointerTable VisibleFunctionTable; IFT IntersectionFunctionTable; - uint32_t Pad[7]; + IR_CONSTANT_PTR(IFT) IntersectionFunctionTables; } IRDispatchRaysArgument; #ifdef IR_RUNTIME_METALCPP @@ -120,8 +141,8 @@ typedef MTLDispatchThreadgroupsIndirectArguments dispatchthreadgroupsindirectarg typedef struct IRRaytracingAccelerationStructureGPUHeader { - IR_CONSTANT_PTR(metal::raytracing::instance_acceleration_structure) accelerationStructureID; - IR_CONSTANT_PTR(uint32_t) addressOfInstanceContributions; + MSLAccelerationStructure accelerationStructureID; + IR_DEVICE_PTR(uint32_t) addressOfInstanceContributions; uint64_t pad0[4]; dispatchthreadgroupsindirectargs_t pad1; } IRRaytracingAccelerationStructureGPUHeader; @@ -142,15 +163,15 @@ typedef struct IRRaytracingInstanceDescriptor #ifdef __METAL_VERSION__ void IRRaytracingUpdateInstanceContributions(IRRaytracingAccelerationStructureGPUHeader header, - IRRaytracingInstanceDescriptor instanceDescriptor, + device IRRaytracingInstanceDescriptor* instanceDescriptor, uint32_t index); #ifdef IR_PRIVATE_IMPLEMENTATION void 
IRRaytracingUpdateInstanceContributions(IRRaytracingAccelerationStructureGPUHeader header, - IRRaytracingInstanceDescriptor instanceDescriptor, + device IRRaytracingInstanceDescriptor* instanceDescriptor, uint32_t index) { - header.addressOfInstanceContributions[index] = instanceDescriptor.InstanceContributionToHitGroupIndex[index]; + header.addressOfInstanceContributions[index] = instanceDescriptor[index].InstanceContributionToHitGroupIndex; } #endif // IR_PRIVATE_IMPLEMENTATION #endif // __METAL_VERSION__ @@ -176,11 +197,19 @@ void IRDescriptorTableSetAccelerationStructure(IRDescriptorTableEntry* entry, ui * @param instanceContributions array of instance contributions to hit group index. * @param instanceCount number of elements in the instanceContributions array. */ +IR_DEPRECATED("use IRRaytracingSetAccelerationStructure variant with instanceContributionArrayBufferGPUAddress.") void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer, resourceid_t accelerationStructure, uint8_t* instanceContributionArrayBuffer, const uint32_t* instanceContributions, uinteger_t instanceCount) IR_OVERLOADABLE; + +void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer, + resourceid_t accelerationStructure, + uint8_t* instanceContributionArrayBuffer, + uint64_t instanceContributionArrayBufferGPUAddress, + const uint32_t* instanceContributions, + uinteger_t instanceContributionsCount) IR_OVERLOADABLE; /** * Initialize a shader identifier to reference a ray generation, closest-hit, any-hit, miss, or callable shader without a @@ -235,7 +264,6 @@ void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer, { IRRaytracingAccelerationStructureGPUHeader* header = (IRRaytracingAccelerationStructureGPUHeader*)headerBuffer; header->accelerationStructureID = accelerationStructure._impl; - header->addressOfInstanceContributions = (uint64_t)instanceContributionArrayBuffer; uint32_t* bufferInstanceContributions = (uint32_t*)instanceContributionArrayBuffer; for (uinteger_t i = 
0; i < instanceCount; ++i) { @@ -243,6 +271,24 @@ void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer, } } +IR_INLINE +void IRRaytracingSetAccelerationStructure(uint8_t* headerBuffer, + resourceid_t accelerationStructure, + uint8_t* instanceContributionArrayBuffer, + uint64_t instanceContributionArrayBufferGPUAddress, + const uint32_t* instanceContributions, + uinteger_t instanceContributionsCount) IR_OVERLOADABLE +{ + IRRaytracingAccelerationStructureGPUHeader* header = (IRRaytracingAccelerationStructureGPUHeader*)headerBuffer; + header->accelerationStructureID = accelerationStructure._impl; + header->addressOfInstanceContributions = instanceContributionArrayBufferGPUAddress; + uint32_t* bufferInstanceContributions = (uint32_t*)instanceContributionArrayBuffer; + for (uinteger_t i = 0; i < instanceContributionsCount; ++i) + { + bufferInstanceContributions[i] = instanceContributions[i]; + } +} + #endif // IR_PRIVATE_IMPLEMENTATION #endif // __METAL_VERSION__ diff --git a/third-party/metal_irconverter_runtime/ir_tessellator_tables.h b/third-party/metal_irconverter_runtime/ir_tessellator_tables.h index a3478d817..c0fa7bacb 100644 --- a/third-party/metal_irconverter_runtime/ir_tessellator_tables.h +++ b/third-party/metal_irconverter_runtime/ir_tessellator_tables.h @@ -1,6 +1,6 @@ //------------------------------------------------------------------------------------------------------------------------------------------------------------- // -// Copyright 2023 Apple Inc. +// Copyright 2023-2025 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/third-party/metal_irconverter_runtime/metal_irconverter_runtime.h b/third-party/metal_irconverter_runtime/metal_irconverter_runtime.h index 8e48ba66d..6285083f7 100644 --- a/third-party/metal_irconverter_runtime/metal_irconverter_runtime.h +++ b/third-party/metal_irconverter_runtime/metal_irconverter_runtime.h @@ -1,6 +1,6 @@ //------------------------------------------------------------------------------------------------------------------------------------------------------------- // -// Copyright 2023 Apple Inc. +// Copyright 2023-2025 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ typedef MTL::Buffer* buffer_t; typedef MTL::SamplerState* sampler_t; typedef MTL::RenderCommandEncoder* renderencoder_t; typedef MTL::PrimitiveType primitivetype_t; -typedef MTL::FunctionConstantValues* functionconstant_t; +typedef MTL::FunctionConstantValues* functionconstantvalues_t; typedef MTL::IndexType indextype_t; typedef MTL::Size mtlsize_t; typedef MTL::RenderPipelineState* renderpipelinestate_t; @@ -47,7 +47,7 @@ typedef id buffer_t; typedef id sampler_t; typedef id renderencoder_t; typedef MTLPrimitiveType primitivetype_t; -typedef MTLFunctionConstantValues* functionconstant_t; +typedef MTLFunctionConstantValues* functionconstantvalues_t; typedef MTLIndexType indextype_t; typedef MTLSize mtlsize_t; typedef id renderpipelinestate_t; @@ -81,8 +81,19 @@ extern const uint64_t kIRArgumentBufferDrawArgumentsBindPoint; extern const uint64_t kIRArgumentBufferUniformsBindPoint; extern const uint64_t kIRVertexBufferBindPoint; extern const uint64_t kIRStageInAttributeStartIndex; -extern const char* kIRIndirectTriangleIntersectionFunctionName; -extern const char* kIRIndirectProceduralIntersectionFunctionName; + +extern const char* kIRIndirectTriangleIntersectionFunctionName; +extern const char* kIRIndirectProceduralIntersectionFunctionName; + +extern 
const char* kIRTrianglePassthroughGeometryShader; +extern const char* kIRLinePassthroughGeometryShader; +extern const char* kIRPointPassthroughGeometryShader; + +extern const char* kIRFunctionGroupRayGeneration; +extern const char* kIRFunctionGroupClosestHit; +extern const char* kIRFunctionGroupMiss; + +extern const uint16_t kIRNonIndexedDraw; typedef struct IRDescriptorTableEntry { @@ -122,6 +133,7 @@ typedef enum IRRuntimePrimitiveType IRRuntimePrimitiveTypeTriangleStrip = 4, IRRuntimePrimitiveTypeLineWithAdj = 5, IRRuntimePrimitiveTypeTriangleWithAdj = 6, + IRRuntimePrimitiveTypeLineStripWithAdj = 7, IRRuntimePrimitiveType1ControlPointPatchlist = IRRuntimePrimitiveTypeTriangle, IRRuntimePrimitiveType2ControlPointPatchlist = IRRuntimePrimitiveTypeTriangle, IRRuntimePrimitiveType3ControlPointPatchlist = IRRuntimePrimitiveTypeTriangle, @@ -251,7 +263,7 @@ typedef struct IRRuntimeDrawParams typedef struct IRRuntimeDrawInfo { // Vertex pipelines only require the index type. - uint16_t indexType; + uint16_t indexType; // set to kIRNonIndexedDraw to indicate a non-indexed draw call // Required by all mesh shader-based pipelines. uint8_t primitiveTopology; @@ -267,8 +279,14 @@ typedef struct IRRuntimeDrawInfo uint64_t indexBuffer; // position aligned to 8 bytes } IRRuntimeDrawInfo; - - +typedef union IRRuntimeFunctionConstantValue +{ + struct { int32_t i0, i1, i2, i3; }; + struct { int16_t s0, s1, s2, s3, s4, s5, s6, s7; }; + struct { int64_t l0, l1; }; + struct { float f0, f1, f2, f3; }; +} IRRuntimeFunctionConstantValue; + /** * Create a BufferView instance representing an append/consume buffer. * Append/consume buffers provide storage and an atomic insert/remove operation. This function takes an @@ -289,16 +307,17 @@ void IRRuntimeCreateAppendBufferView(device_t device, /** * Obtain the count of an append/consume buffer. + * @note the backing MTLBuffer needs to have MTLStorageModeShared storage mode. 
* @param bufferView buffer view representing the append/consume buffer for which to retrieve the counter. * @return the current count of the append/consume buffer. This function doesn't cause a GPU-CPU sync. */ -uint32_t IRRuntimeGetAppendBufferCount(IRBufferView* bufferView); +uint32_t IRRuntimeGetAppendBufferCount(const IRBufferView* bufferView); /** - * Produce metadata from a buffer view description. + * Obtain encoded metadata from a buffer view description. * @param view the view description to encode into the produced metadata. **/ -uint64_t IRDescriptorTableGetBufferMetadata(IRBufferView* view); +uint64_t IRDescriptorTableGetBufferMetadata(const IRBufferView* view); /** * Encode a buffer into the argument buffer. @@ -314,7 +333,7 @@ void IRDescriptorTableSetBuffer(IRDescriptorTableEntry* entry, uint64_t gpu_va, * @param entry the pointer to the descriptor table entry to encode the buffer reference into. * @param bufferView the buffer view description. **/ -void IRDescriptorTableSetBufferView(IRDescriptorTableEntry* entry, IRBufferView* bufferView); +void IRDescriptorTableSetBufferView(IRDescriptorTableEntry* entry, const IRBufferView* bufferView); /** * Encode a texture into the argument buffer. @@ -422,6 +441,7 @@ void IRRuntimeDrawIndexedPrimitives(renderencoder_t enc, primitivetype_t primiti void IRRuntimeDrawIndexedPrimitives(renderencoder_t enc, primitivetype_t primitiveType, indextype_t indexType, buffer_t indexBuffer, uint64_t indexBufferOffset, buffer_t indirectBuffer, uint64_t indirectBufferOffset ) IR_OVERLOADABLE; /** + * Draw indexed primitives using an emulated geometry pipeline. * You need to bind your vertex arrays and strides before issuing this call. * Bind a buffer with IRRuntimeVertexBuffers at index 0 for the object stage, * You need to manually flag residency for all referenced vertex buffers and for the index buffer. 
@@ -438,6 +458,21 @@ void IRRuntimeDrawIndexedPrimitivesGeometryEmulation(renderencoder_t enc, uint32_t baseInstance); /** + * Draw non-indexed primitives using an emulated geometry pipeline. + * You need to bind your vertex arrays and strides before issuing this call. + * Bind a buffer with IRRuntimeVertexBuffers at index 0 for the object stage, + * You need to manually flag residency for all referenced vertex buffers. + */ +void IRRuntimeDrawPrimitivesGeometryEmulation(renderencoder_t enc, + IRRuntimePrimitiveType primitiveType, + IRRuntimeGeometryPipelineConfig geometryPipelineConfig, + uint32_t instanceCount, + uint32_t vertexCountPerInstance, + uint32_t baseVertex, + uint32_t baseInstance); + +/** + * * Draw indexed primitives using an emulated geometry/tessellation pipeline. * You need to bind your vertex arrays and strides before issuing this call. * Bind a buffer with IRRuntimeVertexBuffers at index 0 for the object stage, * You need to manually flag residency for all referenced vertex buffers and for the index buffer. @@ -453,6 +488,20 @@ void IRRuntimeDrawIndexedPatchesTessellationEmulation(renderencoder_t enc, int32_t baseVertex, uint32_t startIndex); +/** + * Draw non-indexed primitives using an emulated geometry/tessellation pipeline. + * You need to bind your vertex arrays and strides before issuing this call. + * Bind a buffer with IRRuntimeVertexBuffers at index 0 for the object stage, + * You need to manually flag residency for all referenced vertex buffers. + */ +void IRRuntimeDrawPatchesTessellationEmulation(renderencoder_t enc, + IRRuntimePrimitiveType primitiveTopology, + IRRuntimeTessellationPipelineConfig tessellationPipelineConfig, + uint32_t instanceCount, + uint32_t vertexCountPerInstance, + uint32_t baseInstance, + uint32_t baseVertex); + /** * Validate that the hull domain and tessellation stages are compatible from their reflection data. 
* @param hsTessellatorOutputPrimitive the tessellator's output primitive which needs to match the geometry stage's input primitive. You can cast this parameter from IRTessellatorOutputPrimitive. @@ -487,6 +536,9 @@ renderpipelinestate_t IRRuntimeNewGeometryEmulationPipeline(device_t device, /** * Create a new mesh pipeline suitable for emulating a render pipeline with a hull stage, a domain stage, and a geometry stage. + * @note You may optionally not provide a geometry shader as part of your tessellation pipeline by setting the descriptor parameter's + * geometryLibrary member to NULL, and providing a geometryFunctionName of kIRTrianglePassthroughGeometryShader, kIRLinePassthroughGeometryShader, + * or kIRPointPassthroughGeometryShader, depending on the primitive topology of draw calls that use the pipeline. * @param device the device to use for creating the new pipeline. * @param descriptor an object describing the origin libraries and function names to create the pipeline. * @param error an output error object containing details about any error encountered during the creation process. @@ -495,6 +547,14 @@ renderpipelinestate_t IRRuntimeNewGeometryEmulationPipeline(device_t device, renderpipelinestate_t IRRuntimeNewGeometryTessellationEmulationPipeline(device_t device, const IRGeometryTessellationEmulationPipelineDescriptor* descriptor, nserror_t* error) API_AVAILABLE(macosx(14), ios(17)); + +/** + * Sets a value for a function constant at a specific index.
+ * @param values the target function constant values object + * @param index the index of the function constant (must be between 0 and 65535) + * @param value the constant value + */ +void IRRuntimeSetFunctionConstantValue(functionconstantvalues_t values, uint16_t index, IRRuntimeFunctionConstantValue *value); #ifdef IR_PRIVATE_IMPLEMENTATION @@ -513,9 +573,19 @@ const uint64_t kIRArgumentBufferDrawArgumentsBindPoint = 4; const uint64_t kIRArgumentBufferUniformsBindPoint = 5; const uint64_t kIRVertexBufferBindPoint = 6; const uint64_t kIRStageInAttributeStartIndex = 11; + const char* kIRIndirectTriangleIntersectionFunctionName = "irconverter.wrapper.intersection.function.triangle"; const char* kIRIndirectProceduralIntersectionFunctionName = "irconverter.wrapper.intersection.function.procedural"; +const char* kIRTrianglePassthroughGeometryShader = "irconverter_domain_shader_triangle_passthrough"; +const char* kIRLinePassthroughGeometryShader = "irconverter_domain_shader_line_passthrough"; +const char* kIRPointPassthroughGeometryShader = "irconverter_domain_shader_point_passthrough"; + +const uint16_t kIRNonIndexedDraw = 0; + +const char* kIRFunctionGroupRayGeneration = "rayGen"; +const char* kIRFunctionGroupClosestHit = "closestHit"; +const char* kIRFunctionGroupMiss = "miss"; const uint64_t kIRBufSizeOffset = 0; const uint64_t kIRBufSizeMask = 0xffffffff; @@ -584,7 +654,7 @@ void IRRuntimeCreateAppendBufferView(device_t device, buffer_t appendBuffer, uin } IR_INLINE -uint32_t IRRuntimeGetAppendBufferCount(IRBufferView* bufferView) +uint32_t IRRuntimeGetAppendBufferCount(const IRBufferView* bufferView) { uint64_t bufferOffset = bufferView->textureViewOffsetInElements * 4; #ifdef IR_RUNTIME_METALCPP @@ -597,7 +667,7 @@ uint32_t IRRuntimeGetAppendBufferCount(IRBufferView* bufferView) IR_INLINE -uint64_t IRDescriptorTableGetBufferMetadata(IRBufferView* view) +uint64_t IRDescriptorTableGetBufferMetadata(const IRBufferView* view) { uint64_t md = (view->bufferSize & 
kIRBufSizeMask) << kIRBufSizeOffset; @@ -617,7 +687,7 @@ void IRDescriptorTableSetBuffer(IRDescriptorTableEntry* entry, uint64_t gpu_va, } IR_INLINE -void IRDescriptorTableSetBufferView(IRDescriptorTableEntry* entry, IRBufferView* bufferView) +void IRDescriptorTableSetBufferView(IRDescriptorTableEntry* entry, const IRBufferView* bufferView) { #ifdef IR_RUNTIME_METALCPP entry->gpuVA = bufferView->buffer->gpuAddress() + bufferView->bufferOffset; @@ -657,20 +727,25 @@ void IRDescriptorTableSetSampler(IRDescriptorTableEntry* entry, sampler_t argume entry->metadata = encodedLodBias; } +static IR_INLINE +uint16_t IRMetalIndexToIRIndex(indextype_t indexType) +{ + return (uint16_t)(indexType+1); +} + IR_INLINE void IRRuntimeDrawPrimitives(renderencoder_t enc, primitivetype_t primitiveType, uint64_t vertexStart, uint64_t vertexCount, uint64_t instanceCount, uint64_t baseInstance) IR_OVERLOADABLE { IRRuntimeDrawArgument da = { (uint32_t)vertexCount, (uint32_t)instanceCount, (uint32_t)vertexStart, (uint32_t)baseInstance }; IRRuntimeDrawParams dp = { .draw = da }; - IRRuntimeDrawInfo di = { 0, (uint8_t)primitiveType, 0, 0, 0 }; #ifdef IR_RUNTIME_METALCPP enc->setVertexBytes( &dp, sizeof( IRRuntimeDrawParams ), kIRArgumentBufferDrawArgumentsBindPoint ); - enc->setVertexBytes( &di, sizeof( IRRuntimeDrawInfo ), kIRArgumentBufferUniformsBindPoint ); + enc->setVertexBytes( &kIRNonIndexedDraw, sizeof( uint16_t ), kIRArgumentBufferUniformsBindPoint ); enc->drawPrimitives( primitiveType, vertexStart, vertexCount, instanceCount, baseInstance ); #else [enc setVertexBytes:&dp length:sizeof( IRRuntimeDrawParams ) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; - [enc setVertexBytes:&di length:sizeof( IRRuntimeDrawInfo ) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc setVertexBytes:&kIRNonIndexedDraw length:sizeof( uint16_t ) atIndex:kIRArgumentBufferUniformsBindPoint]; [enc drawPrimitives:primitiveType vertexStart:vertexStart vertexCount:vertexCount instanceCount:instanceCount 
baseInstance:baseInstance]; #endif } @@ -692,9 +767,11 @@ void IRRuntimeDrawPrimitives(renderencoder_t enc, primitivetype_t primitiveType, { #ifdef IR_RUNTIME_METALCPP enc->setVertexBuffer( indirectBuffer, indirectBufferOffset, kIRArgumentBufferDrawArgumentsBindPoint ); + enc->setVertexBytes( &kIRNonIndexedDraw, sizeof( uint16_t ), kIRArgumentBufferUniformsBindPoint ); enc->drawPrimitives( primitiveType, indirectBuffer, indirectBufferOffset ); #else [enc setVertexBuffer:indirectBuffer offset:indirectBufferOffset atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + [enc setVertexBytes:&kIRNonIndexedDraw length:sizeof( uint16_t ) atIndex:kIRArgumentBufferUniformsBindPoint]; [enc drawPrimitives:primitiveType indirectBuffer:indirectBuffer indirectBufferOffset:indirectBufferOffset]; #endif } @@ -711,15 +788,15 @@ void IRRuntimeDrawIndexedPrimitives(renderencoder_t enc, primitivetype_t primiti }; IRRuntimeDrawParams dp = { .drawIndexed = da }; - IRRuntimeDrawInfo di = { .indexType = (uint8_t)(indexType+1), .primitiveTopology = (uint8_t)primitiveType }; + const uint16_t IRIndexType = IRMetalIndexToIRIndex(indexType); #ifdef IR_RUNTIME_METALCPP enc->setVertexBytes( &dp, sizeof( IRRuntimeDrawParams ), kIRArgumentBufferDrawArgumentsBindPoint ); - enc->setVertexBytes( &di, sizeof( IRRuntimeDrawInfo ), kIRArgumentBufferUniformsBindPoint ); + enc->setVertexBytes( &IRIndexType, sizeof( uint16_t ), kIRArgumentBufferUniformsBindPoint ); enc->drawIndexedPrimitives( primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount, baseVertex, baseInstance ); #else [enc setVertexBytes:&dp length:sizeof( IRRuntimeDrawParams ) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; - [enc setVertexBytes:&di length:sizeof( IRRuntimeDrawInfo ) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc setVertexBytes:&IRIndexType length:sizeof( uint16_t ) atIndex:kIRArgumentBufferUniformsBindPoint]; [enc drawIndexedPrimitives:primitiveType indexCount:indexCount indexType:indexType 
indexBuffer:indexBuffer indexBufferOffset:indexBufferOffset instanceCount:instanceCount baseVertex:baseVertex baseInstance:baseInstance]; #endif } @@ -739,11 +816,15 @@ void IRRuntimeDrawIndexedPrimitives(renderencoder_t enc, primitivetype_t primiti IR_INLINE void IRRuntimeDrawIndexedPrimitives(renderencoder_t enc, primitivetype_t primitiveType, indextype_t indexType, buffer_t indexBuffer, uint64_t indexBufferOffset, buffer_t indirectBuffer, uint64_t indirectBufferOffset ) IR_OVERLOADABLE { + const uint16_t IRIndexType = IRMetalIndexToIRIndex(indexType); + #ifdef IR_RUNTIME_METALCPP enc->setVertexBuffer( indirectBuffer, indirectBufferOffset, kIRArgumentBufferDrawArgumentsBindPoint ); + enc->setVertexBytes( &IRIndexType, sizeof( uint16_t ), kIRArgumentBufferUniformsBindPoint ); enc->drawIndexedPrimitives( primitiveType, indexType, indexBuffer, indexBufferOffset, indirectBuffer, indirectBufferOffset ); #else [enc setVertexBuffer:indirectBuffer offset:indirectBufferOffset atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + [enc setVertexBytes:&IRIndexType length:sizeof( uint16_t ) atIndex:kIRArgumentBufferUniformsBindPoint]; [enc drawIndexedPrimitives:primitiveType indexType:indexType indexBuffer:indexBuffer indexBufferOffset:indexBufferOffset indirectBuffer:indirectBuffer indirectBufferOffset:indirectBufferOffset]; #endif } @@ -760,6 +841,7 @@ static uint32_t IRRuntimePrimitiveTypeVertexCount(IRRuntimePrimitiveType primiti case IRRuntimePrimitiveTypeLineWithAdj: return 4; break; case IRRuntimePrimitiveTypeTriangleStrip: return 3; break; case IRRuntimePrimitiveTypeLineStrip: return 2; break; + case IRRuntimePrimitiveTypeLineStripWithAdj:return 4; break; default: return 0; } return 0; @@ -777,23 +859,7 @@ static uint32_t IRRuntimePrimitiveTypeVertexOverlap(IRRuntimePrimitiveType primi case IRRuntimePrimitiveTypeLineWithAdj: return 0; break; case IRRuntimePrimitiveTypeTriangleStrip: return 2; break; case IRRuntimePrimitiveTypeLineStrip: return 1; break; - default: 
return 0; - } - return 0; -} - -IR_INLINE -static uint32_t DXAIRPrimitiveTypeToPrimitiveTopology(IRRuntimePrimitiveType primitiveType) -{ - switch (primitiveType) - { - case IRRuntimePrimitiveTypePoint: return 1; - case IRRuntimePrimitiveTypeLine: return 2; - case IRRuntimePrimitiveTypeLineStrip: return 3; - case IRRuntimePrimitiveTypeTriangle: return 4; - case IRRuntimePrimitiveTypeTriangleStrip: return 5; - case IRRuntimePrimitiveTypeLineWithAdj: return 10; - case IRRuntimePrimitiveTypeTriangleWithAdj: return 12; + case IRRuntimePrimitiveTypeLineStripWithAdj:return 3; break; default: return 0; } return 0; @@ -842,10 +908,10 @@ static mtlsize_t IRRuntimeCalculateObjectTgCountForTessellationAndGeometryEmulat uint32_t instanceCount) { uint32_t nonProvokingVertices = 0; - if (primitiveType == (uint32_t)IRRuntimePrimitiveTypeLineStrip || primitiveType == (uint32_t)IRRuntimePrimitiveTypeTriangleStrip) + if (primitiveType == (uint32_t)IRRuntimePrimitiveTypeLineStrip || primitiveType == (uint32_t)IRRuntimePrimitiveTypeTriangleStrip || primitiveType == (uint32_t)IRRuntimePrimitiveTypeLineStripWithAdj) { // For strips, last k vertices aren't able to spawn a full primitive. - nonProvokingVertices = (primitiveType == (uint32_t)IRRuntimePrimitiveTypeTriangleStrip) ? 
2 : 1; + nonProvokingVertices = IRRuntimePrimitiveTypeVertexOverlap(primitiveType); } mtlsize_t siz; @@ -906,11 +972,11 @@ void IRRuntimeDrawIndexedPrimitivesGeometryEmulation(renderencoder_t enc, IRRuntimeDrawParams drawParams; drawParams.drawIndexed = (IRRuntimeDrawIndexedArgument){ - .startInstanceLocation = baseInstance, + .indexCountPerInstance = indexCountPerInstance, .instanceCount = instanceCount, - .baseVertexLocation = baseVertex, .startIndexLocation = startIndex, - .indexCountPerInstance = indexCountPerInstance + .baseVertexLocation = baseVertex, + .startInstanceLocation = baseInstance }; @@ -941,6 +1007,67 @@ void IRRuntimeDrawIndexedPrimitivesGeometryEmulation(renderencoder_t enc, #endif } + +IR_INLINE +void IRRuntimeDrawPrimitivesGeometryEmulation(renderencoder_t enc, + IRRuntimePrimitiveType primitiveType, + IRRuntimeGeometryPipelineConfig geometryPipelineConfig, + uint32_t instanceCount, + uint32_t vertexCountPerInstance, + uint32_t baseVertex, + uint32_t baseInstance) +{ + IRRuntimeDrawInfo drawInfo = IRRuntimeCalculateDrawInfoForGSEmulation(primitiveType, + (indextype_t)-1, + geometryPipelineConfig.gsVertexSizeInBytes, + geometryPipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup, + instanceCount); + drawInfo.indexType = kIRNonIndexedDraw; + + mtlsize_t objectThreadgroupCount = IRRuntimeCalculateObjectTgCountForTessellationAndGeometryEmulation(vertexCountPerInstance, + drawInfo.objectThreadgroupVertexStride, + primitiveType, + instanceCount); + + uint32_t objectThreadgroupSize,meshThreadgroupSize; + IRRuntimeCalculateThreadgroupSizeForGeometry(primitiveType, + geometryPipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup, + drawInfo.objectThreadgroupVertexStride, + &objectThreadgroupSize, + &meshThreadgroupSize); + + IRRuntimeDrawParams drawParams; + drawParams.draw = (IRRuntimeDrawArgument){ + .vertexCountPerInstance = vertexCountPerInstance, + .instanceCount = instanceCount, + .startVertexLocation = baseVertex, + .startInstanceLocation = 
baseInstance + }; + + + #ifdef IR_RUNTIME_METALCPP + + enc->setObjectBytes(&drawInfo, sizeof(IRRuntimeDrawInfo), kIRArgumentBufferUniformsBindPoint); + enc->setMeshBytes(&drawInfo, sizeof(IRRuntimeDrawInfo), kIRArgumentBufferUniformsBindPoint); + enc->setObjectBytes(&drawParams, sizeof(IRRuntimeDrawParams), kIRArgumentBufferDrawArgumentsBindPoint); + enc->setMeshBytes(&drawParams, sizeof(IRRuntimeDrawParams), kIRArgumentBufferDrawArgumentsBindPoint); + + enc->drawMeshThreadgroups(objectThreadgroupCount, MTL::Size::Make(objectThreadgroupSize, 1, 1), MTL::Size::Make(meshThreadgroupSize, 1, 1)); + + #else + + [enc setObjectBytes:&drawInfo length:sizeof(IRRuntimeDrawInfo) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc setMeshBytes:&drawInfo length:sizeof(IRRuntimeDrawInfo) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc setObjectBytes:&drawParams length:sizeof(IRRuntimeDrawParams) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + [enc setMeshBytes:&drawParams length:sizeof(IRRuntimeDrawParams) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + + [enc drawMeshThreadgroups:objectThreadgroupCount + threadsPerObjectThreadgroup:MTLSizeMake(objectThreadgroupSize, 1, 1) + threadsPerMeshThreadgroup:MTLSizeMake(meshThreadgroupSize, 1, 1)]; + + #endif +} + IR_INLINE static uint16_t IRTessellatorThreadgroupVertexOverlap(IRRuntimeTessellatorOutputPrimitive tessellatorOutputPrimitive) { @@ -1020,21 +1147,14 @@ void IRRuntimeDrawIndexedPatchesTessellationEmulation(renderencoder_t enc, IRRuntimeDrawParams drawParams; drawParams.drawIndexed = (IRRuntimeDrawIndexedArgument){ - .startInstanceLocation = baseInstance, + .indexCountPerInstance = indexCountPerInstance, .instanceCount = instanceCount, - .baseVertexLocation = baseVertex, .startIndexLocation = startIndex, - .indexCountPerInstance = indexCountPerInstance + .baseVertexLocation = baseVertex, + .startInstanceLocation = baseInstance }; - uint32_t threadgroupMem = 16; - uint32_t prefixSumMem = 16; - - threadgroupMem = 
tessellationPipelineConfig.vsOutputSizeInBytes * - tessellationPipelineConfig.hsInputControlPointCount * - tessellationPipelineConfig.hsMaxPatchesPerObjectThreadgroup; - - prefixSumMem = 15360 - (32 * 4); + uint32_t threadgroupMem = 15360; #ifdef IR_RUNTIME_METALCPP drawInfo.indexBuffer = indexBuffer->gpuAddress(); @@ -1045,7 +1165,6 @@ void IRRuntimeDrawIndexedPatchesTessellationEmulation(renderencoder_t enc, enc->setMeshBytes(&drawParams, sizeof(IRRuntimeDrawParams), kIRArgumentBufferDrawArgumentsBindPoint); enc->setObjectThreadgroupMemoryLength(threadgroupMem, 0); - enc->setObjectThreadgroupMemoryLength(prefixSumMem, 1); enc->drawMeshThreadgroups(objectThreadgroupCount, MTL::Size::Make(objectThreadgroupSize, 1, 1), @@ -1060,7 +1179,78 @@ void IRRuntimeDrawIndexedPatchesTessellationEmulation(renderencoder_t enc, [enc setMeshBytes:&drawParams length:sizeof(IRRuntimeDrawParams) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; [enc setObjectThreadgroupMemoryLength:threadgroupMem atIndex:0]; - [enc setObjectThreadgroupMemoryLength:prefixSumMem atIndex:1]; + + [enc drawMeshThreadgroups:objectThreadgroupCount + threadsPerObjectThreadgroup:MTLSizeMake(objectThreadgroupSize, 1, 1) + threadsPerMeshThreadgroup:MTLSizeMake(meshThreadgroupSize, 1, 1)]; +#endif // IR_RUNTIME_METALCPP +} + +IR_INLINE +void IRRuntimeDrawPatchesTessellationEmulation(renderencoder_t enc, + IRRuntimePrimitiveType primitiveTopology, + IRRuntimeTessellationPipelineConfig tessellationPipelineConfig, + uint32_t instanceCount, + uint32_t vertexCountPerInstance, + uint32_t baseInstance, + uint32_t baseVertex) +{ + IRRuntimeDrawInfo drawInfo = IRRuntimeCalculateDrawInfoForGSTSEmulation( + /* primitiveType */ primitiveTopology, + /* indexType */ (indextype_t)-1, + /* tessellatorOutputPrimitive */ tessellationPipelineConfig.outputPrimitiveType, + /* gsMaxInputPrimitivesPerMeshThreadgroup */ tessellationPipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup, + /* hsPatchesPerObjectThreadgroup */ 
tessellationPipelineConfig.hsMaxPatchesPerObjectThreadgroup, + /* hsInputControlPointsPerPatch */ tessellationPipelineConfig.hsInputControlPointCount, + /* hsObjectThreadsPerPatch */ tessellationPipelineConfig.hsMaxObjectThreadsPerThreadgroup, + /* gsInstanceCount */ tessellationPipelineConfig.gsInstanceCount); + drawInfo.indexType = kIRNonIndexedDraw; + + + mtlsize_t objectThreadgroupCount = IRRuntimeCalculateObjectTgCountForTessellationAndGeometryEmulation(vertexCountPerInstance, + drawInfo.objectThreadgroupVertexStride, + primitiveTopology, + instanceCount); + + uint32_t objectThreadgroupSize, meshThreadgroupSize; + IRRuntimeCalculateThreadgroupSizeForTessellationAndGeometry(tessellationPipelineConfig.hsMaxPatchesPerObjectThreadgroup, + tessellationPipelineConfig.hsMaxObjectThreadsPerThreadgroup, + tessellationPipelineConfig.gsMaxInputPrimitivesPerMeshThreadgroup, + &objectThreadgroupSize, + &meshThreadgroupSize); + + + IRRuntimeDrawParams drawParams; + drawParams.draw = (IRRuntimeDrawArgument){ + .vertexCountPerInstance = vertexCountPerInstance, + .instanceCount = instanceCount, + .startVertexLocation = baseVertex, + .startInstanceLocation = baseInstance + + }; + + uint32_t threadgroupMem = 15360; + +#ifdef IR_RUNTIME_METALCPP + + enc->setObjectBytes(&drawInfo, sizeof(IRRuntimeDrawInfo), kIRArgumentBufferUniformsBindPoint); + enc->setMeshBytes(&drawInfo, sizeof(IRRuntimeDrawInfo), kIRArgumentBufferUniformsBindPoint); + enc->setObjectBytes(&drawParams, sizeof(IRRuntimeDrawParams), kIRArgumentBufferDrawArgumentsBindPoint); + enc->setMeshBytes(&drawParams, sizeof(IRRuntimeDrawParams), kIRArgumentBufferDrawArgumentsBindPoint); + + enc->setObjectThreadgroupMemoryLength(threadgroupMem, 0); + + enc->drawMeshThreadgroups(objectThreadgroupCount, + MTL::Size::Make(objectThreadgroupSize, 1, 1), + MTL::Size::Make(meshThreadgroupSize, 1, 1)); +#else + + [enc setObjectBytes:&drawInfo length:sizeof(IRRuntimeDrawInfo) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc 
setMeshBytes:&drawInfo length:sizeof(IRRuntimeDrawInfo) atIndex:kIRArgumentBufferUniformsBindPoint]; + [enc setObjectBytes:&drawParams length:sizeof(IRRuntimeDrawParams) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + [enc setMeshBytes:&drawParams length:sizeof(IRRuntimeDrawParams) atIndex:kIRArgumentBufferDrawArgumentsBindPoint]; + + [enc setObjectThreadgroupMemoryLength:threadgroupMem atIndex:0]; [enc drawMeshThreadgroups:objectThreadgroupCount threadsPerObjectThreadgroup:MTLSizeMake(objectThreadgroupSize, 1, 1) @@ -1268,6 +1458,16 @@ renderpipelinestate_t IRRuntimeNewGeometryEmulationPipeline(device_t device, con #endif // IR_RUNTIME_METALCPP } +IR_INLINE +void IRRuntimeSetFunctionConstantValue(functionconstantvalues_t values, uint16_t index, IRRuntimeFunctionConstantValue *value) +{ +#ifdef IR_RUNTIME_METALCPP + values->setConstantValue(value, MTL::DataTypeUInt4, index); +#else + [values setConstantValue:value type:MTLDataTypeUInt4 atIndex:index]; +#endif +} + IR_INLINE renderpipelinestate_t IRRuntimeNewGeometryTessellationEmulationPipeline(device_t device, const IRGeometryTessellationEmulationPipelineDescriptor* descriptor, nserror_t* error) API_AVAILABLE(macosx(14), ios(17)) { @@ -1376,31 +1576,37 @@ renderpipelinestate_t IRRuntimeNewGeometryTessellationEmulationPipeline(device_t // Geometry function: { - // Not done here: verify the stage is not just passthrough. 
- // Configure function: - bool enableTessellationEmulation = true; - bool enableStreamOut = false; - - MTL::FunctionConstantValues* pFunctionConstants = MTL::FunctionConstantValues::alloc()->init(); - - pFunctionConstants->setConstantValue(&enableTessellationEmulation, - MTL::DataTypeBool, MTLSTR("tessellationEnabled")); - - pFunctionConstants->setConstantValue(&enableStreamOut, - MTL::DataTypeBool, MTLSTR("streamOutEnabled")); - - pFunctionConstants->setConstantValue(&(descriptor->pipelineConfig.vsOutputSizeInBytes), - MTL::DataTypeInt, MTLSTR("vertex_shader_output_size_fc")); - - MTL::FunctionDescriptor* pFunctionDesc = MTL::FunctionDescriptor::alloc()->init(); - pFunctionDesc->setConstantValues(pFunctionConstants); - pFunctionDesc->setName( NS::String::string(descriptor->geometryFunctionName, NS::UTF8StringEncoding) ); - - pGeometryFn = descriptor->geometryLibrary->newFunction(pFunctionDesc, error); - - pFunctionDesc->release(); - pFunctionConstants->release(); + if (descriptor->geometryLibrary != nullptr) + { + bool enableTessellationEmulation = true; + bool enableStreamOut = false; + + MTL::FunctionConstantValues* pFunctionConstants = MTL::FunctionConstantValues::alloc()->init(); + + pFunctionConstants->setConstantValue(&enableTessellationEmulation, + MTL::DataTypeBool, MTLSTR("tessellationEnabled")); + + pFunctionConstants->setConstantValue(&enableStreamOut, + MTL::DataTypeBool, MTLSTR("streamOutEnabled")); + + pFunctionConstants->setConstantValue(&(descriptor->pipelineConfig.vsOutputSizeInBytes), + MTL::DataTypeInt, MTLSTR("vertex_shader_output_size_fc")); + + MTL::FunctionDescriptor* pFunctionDesc = MTL::FunctionDescriptor::alloc()->init(); + pFunctionDesc->setConstantValues(pFunctionConstants); + pFunctionDesc->setName( NS::String::string(descriptor->geometryFunctionName, NS::UTF8StringEncoding) ); + + pGeometryFn = descriptor->geometryLibrary->newFunction(pFunctionDesc, error); + + pFunctionDesc->release(); + pFunctionConstants->release(); + } + else + { 
+ NS::String* passthroughName = NS::String::string(descriptor->geometryFunctionName, NS::UTF8StringEncoding); + pGeometryFn = descriptor->domainLibrary->newFunction(passthroughName); + } if (!pGeometryFn) { goto exit_geometry_function_error; @@ -1532,8 +1738,7 @@ renderpipelinestate_t IRRuntimeNewGeometryTessellationEmulationPipeline(device_t MTLFunctionDescriptor* pFunctionDesc = [[MTLFunctionDescriptor alloc] init]; [pFunctionDesc setConstantValues:pFunctionConstants]; - - NSString* functionName = [NSString stringWithFormat:@"%s.dxil_irconverter_object_shader", descriptor->vertexFunctionName]; + pFunctionDesc.name = [NSString stringWithFormat:@"%s.dxil_irconverter_object_shader", descriptor->vertexFunctionName]; pVertexFn = [descriptor->vertexLibrary newFunctionWithDescriptor:pFunctionDesc error:error]; @@ -1584,32 +1789,37 @@ renderpipelinestate_t IRRuntimeNewGeometryTessellationEmulationPipeline(device_t // Geometry function: { - // Not done here: verify the stage is not just passthrough. 
- - // Configure function: - bool enableTessellationEmulation = true; - bool enableStreamOut = false; - - MTLFunctionConstantValues* pFunctionConstants = [[MTLFunctionConstantValues alloc] init]; - - [pFunctionConstants setConstantValue:&enableTessellationEmulation - type:MTLDataTypeBool - withName:@"tessellationEnabled"]; - - [pFunctionConstants setConstantValue:&enableStreamOut - type:MTLDataTypeBool - withName:@"streamOutEnabled"]; - - [pFunctionConstants setConstantValue:&(descriptor->pipelineConfig.vsOutputSizeInBytes) - type:MTLDataTypeInt - withName:@"vertex_shader_output_size_fc"]; - - MTLFunctionDescriptor* pFunctionDesc = [[MTLFunctionDescriptor alloc] init]; - [pFunctionDesc setConstantValues:pFunctionConstants]; - [pFunctionDesc setName:[NSString stringWithUTF8String:descriptor->geometryFunctionName]]; - - pGeometryFn = [descriptor->geometryLibrary newFunctionWithDescriptor:pFunctionDesc error:error]; - + if (descriptor->geometryLibrary != nil) + { + // Configure function: + bool enableTessellationEmulation = true; + bool enableStreamOut = false; + + MTLFunctionConstantValues* pFunctionConstants = [[MTLFunctionConstantValues alloc] init]; + + [pFunctionConstants setConstantValue:&enableTessellationEmulation + type:MTLDataTypeBool + withName:@"tessellationEnabled"]; + + [pFunctionConstants setConstantValue:&enableStreamOut + type:MTLDataTypeBool + withName:@"streamOutEnabled"]; + + [pFunctionConstants setConstantValue:&(descriptor->pipelineConfig.vsOutputSizeInBytes) + type:MTLDataTypeInt + withName:@"vertex_shader_output_size_fc"]; + + MTLFunctionDescriptor* pFunctionDesc = [[MTLFunctionDescriptor alloc] init]; + [pFunctionDesc setConstantValues:pFunctionConstants]; + [pFunctionDesc setName:[NSString stringWithUTF8String:descriptor->geometryFunctionName]]; + + pGeometryFn = [descriptor->geometryLibrary newFunctionWithDescriptor:pFunctionDesc error:error]; + } + else + { + NSString* passthroughName = [NSString 
stringWithUTF8String:descriptor->geometryFunctionName]; + pGeometryFn = [descriptor->domainLibrary newFunctionWithName:passthroughName]; + } if (!pGeometryFn) { return nil; diff --git a/tools/api-query/CMakeLists.txt b/tools/api-query/CMakeLists.txt index 9e55fb2f2..ac8359391 100644 --- a/tools/api-query/CMakeLists.txt +++ b/tools/api-query/CMakeLists.txt @@ -8,3 +8,9 @@ if (APPLE AND OFFLOADTEST_ENABLE_VULKAN) set_property(TARGET api-query APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-rpath,${_Vulkan_LIB_DIR} ") endif() + +if (APPLE) + get_filename_component(_MetalIRConverter_LIB_DIR ${METAL_IRCONVERTER_LIBRARY} DIRECTORY) + set_property(TARGET api-query APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,-rpath,${_MetalIRConverter_LIB_DIR} ") +endif() diff --git a/tools/offloader/CMakeLists.txt b/tools/offloader/CMakeLists.txt index d4c6120b8..8e04d30a1 100644 --- a/tools/offloader/CMakeLists.txt +++ b/tools/offloader/CMakeLists.txt @@ -12,3 +12,9 @@ if (APPLE AND OFFLOADTEST_ENABLE_VULKAN) set_property(TARGET offloader APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-rpath,${_Vulkan_LIB_DIR} ") endif() + +if (APPLE) + get_filename_component(_MetalIRConverter_LIB_DIR ${METAL_IRCONVERTER_LIBRARY} DIRECTORY) + set_property(TARGET offloader APPEND_STRING PROPERTY + LINK_FLAGS " -Wl,-rpath,${_MetalIRConverter_LIB_DIR} ") +endif() diff --git a/tools/offloader/offloader.cpp b/tools/offloader/offloader.cpp index f81e43cbb..5d3b85d27 100644 --- a/tools/offloader/offloader.cpp +++ b/tools/offloader/offloader.cpp @@ -69,11 +69,6 @@ static cl::opt AdapterRegex( "Case-insensitive regular expression to match GPU adapter description"), cl::value_desc(""), cl::init("")); -static cl::list Reflection( - "reflection", - cl::desc("Filenames for shader reflection metadata (Metal only)"), - cl::value_desc("filename")); - static std::unique_ptr readFile(const std::string &Path) { const ExitOnError ExitOnErr("gpu-exec: error: "); ErrorOr> FileOrErr = @@ -121,9 +116,6 @@ int run() { // Read in the shaders for 
(size_t I = 0; I < InputShader.size(); ++I) { PipelineDesc.Shaders[I].Shader = readFile(InputShader[I]); - if (I < Reflection.size()) { - PipelineDesc.Shaders[I].Reflection = readFile(Reflection[I]); - } } if (InputShader.size() != PipelineDesc.Shaders.size()) @@ -136,15 +128,17 @@ int run() { const StringRef Binary = PipelineDesc.Shaders[0].Shader->getBuffer(); if (APIToUse == GPUAPI::Unknown) { if (Binary.starts_with("DXBC")) { +#ifdef __APPLE__ + APIToUse = GPUAPI::Metal; + outs() << "Using Metal API\n"; +#else APIToUse = GPUAPI::DirectX; outs() << "Using DirectX API\n"; +#endif } else if (*reinterpret_cast(Binary.data()) == 0x07230203) { APIToUse = GPUAPI::Vulkan; outs() << "Using Vulkan API\n"; - } else if (Binary.starts_with("MTLB")) { - APIToUse = GPUAPI::Metal; - outs() << "Using Metal API\n"; } }