|
| 1 | +#include <Babylon/Plugins/ShaderCompiler.h> |
| 2 | + |
| 3 | +#include "ShaderCompilerCommon.h" |
| 4 | +#include "ShaderCompilerTraversers.h" |
| 5 | +#include <bgfx/bgfx.h> |
| 6 | +#include <glslang/Public/ShaderLang.h> |
| 7 | +#include <glslang/Public/ResourceLimits.h> |
| 8 | +#include <SPIRV/GlslangToSpv.h> |
| 9 | +#include <spirv_parser.hpp> |
| 10 | +#include <spirv_hlsl.hpp> |
| 11 | +#include <wrl/client.h> |
| 12 | +#include <dxcapi.h> |
| 13 | + |
| 14 | +namespace |
| 15 | +{ |
| 16 | + void AddShader(glslang::TProgram& program, glslang::TShader& shader, std::string_view source) |
| 17 | + { |
| 18 | + const std::array<const char*, 1> sources{source.data()}; |
| 19 | + shader.setStrings(sources.data(), gsl::narrow_cast<int>(sources.size())); |
| 20 | + |
| 21 | + auto defaultTBuiltInResource = GetDefaultResources(); |
| 22 | + |
| 23 | + if (!shader.parse(defaultTBuiltInResource, 310, EProfile::EEsProfile, true, true, EShMsgDefault)) |
| 24 | + { |
| 25 | + throw std::runtime_error{shader.getInfoLog()}; |
| 26 | + } |
| 27 | + |
| 28 | + program.addShader(&shader); |
| 29 | + } |
| 30 | + |
| 31 | + struct DxcCompilerState |
| 32 | + { |
| 33 | + // Order matters: ComPtrs must be destroyed before FreeLibrary so COM |
| 34 | + // object vtables (which live inside dxcompiler.dll) remain valid during |
| 35 | + // Release(). Members are destroyed in reverse declaration order. |
| 36 | + struct ModuleDeleter |
| 37 | + { |
| 38 | + void operator()(HMODULE m) const noexcept |
| 39 | + { |
| 40 | + if (m != nullptr) |
| 41 | + { |
| 42 | + FreeLibrary(m); |
| 43 | + } |
| 44 | + } |
| 45 | + }; |
| 46 | + std::unique_ptr<std::remove_pointer_t<HMODULE>, ModuleDeleter> Module; |
| 47 | + Microsoft::WRL::ComPtr<IDxcCompiler3> Compiler; |
| 48 | + }; |
| 49 | + |
| 50 | + // Thread-safe lazy initialization via C++11 magic statics. |
| 51 | + DxcCompilerState& GetDxcCompiler() |
| 52 | + { |
| 53 | + static DxcCompilerState state = []() { |
| 54 | + DxcCompilerState s; |
| 55 | + |
| 56 | + s.Module.reset(LoadLibraryExW(L"dxcompiler.dll", nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32 | LOAD_LIBRARY_SEARCH_APPLICATION_DIR | LOAD_LIBRARY_SEARCH_DEFAULT_DIRS)); |
| 57 | + if (!s.Module) |
| 58 | + { |
| 59 | + throw std::runtime_error{"Failed to load dxcompiler.dll"}; |
| 60 | + } |
| 61 | + |
| 62 | + auto createInstance = reinterpret_cast<DxcCreateInstanceProc>( |
| 63 | + GetProcAddress(s.Module.get(), "DxcCreateInstance")); |
| 64 | + if (!createInstance) |
| 65 | + { |
| 66 | + throw std::runtime_error{"Failed to find DxcCreateInstance in dxcompiler.dll"}; |
| 67 | + } |
| 68 | + |
| 69 | + if (FAILED(createInstance(CLSID_DxcCompiler, IID_PPV_ARGS(&s.Compiler)))) |
| 70 | + { |
| 71 | + throw std::runtime_error{"Failed to create IDxcCompiler3"}; |
| 72 | + } |
| 73 | + |
| 74 | + return s; |
| 75 | + }(); |
| 76 | + return state; |
| 77 | + } |
| 78 | + |
| 79 | + std::pair<std::unique_ptr<spirv_cross::Parser>, std::unique_ptr<spirv_cross::Compiler>> CompileShader(glslang::TProgram& program, EShLanguage stage, gsl::span<const spirv_cross::HLSLVertexAttributeRemap> attributes, IDxcBlob** blob) |
| 80 | + { |
| 81 | + std::vector<uint32_t> spirv; |
| 82 | + glslang::GlslangToSpv(*program.getIntermediate(stage), spirv); |
| 83 | + |
| 84 | + auto parser = std::make_unique<spirv_cross::Parser>(std::move(spirv)); |
| 85 | + parser->parse(); |
| 86 | + |
| 87 | + auto compiler = std::make_unique<spirv_cross::CompilerHLSL>(parser->get_parsed_ir()); |
| 88 | + |
| 89 | + compiler->set_hlsl_options({40, true}); |
| 90 | + |
| 91 | + for (const auto& attribute : attributes) |
| 92 | + { |
| 93 | + compiler->add_vertex_attribute_remap(attribute); |
| 94 | + } |
| 95 | + |
| 96 | + std::string hlsl = compiler->compile(); |
| 97 | + |
| 98 | + auto& dxc = GetDxcCompiler(); |
| 99 | + |
| 100 | + const wchar_t* target = stage == EShLangVertex ? L"vs_6_0" : L"ps_6_0"; |
| 101 | + |
| 102 | + std::vector<LPCWSTR> args = {L"-E", L"main", L"-T", target}; |
| 103 | +#ifdef _DEBUG |
| 104 | + args.push_back(L"-Zi"); |
| 105 | + args.push_back(L"-Od"); |
| 106 | +#endif |
| 107 | + |
| 108 | + DxcBuffer sourceBuffer{}; |
| 109 | + sourceBuffer.Ptr = hlsl.data(); |
| 110 | + sourceBuffer.Size = hlsl.size(); |
| 111 | + sourceBuffer.Encoding = DXC_CP_UTF8; |
| 112 | + |
| 113 | + Microsoft::WRL::ComPtr<IDxcResult> result; |
| 114 | + if (FAILED(dxc.Compiler->Compile( |
| 115 | + &sourceBuffer, |
| 116 | + args.data(), |
| 117 | + static_cast<UINT32>(args.size()), |
| 118 | + nullptr, |
| 119 | + IID_PPV_ARGS(&result)))) |
| 120 | + { |
| 121 | + throw std::runtime_error{"DXC compilation call failed"}; |
| 122 | + } |
| 123 | + |
| 124 | + HRESULT status; |
| 125 | + result->GetStatus(&status); |
| 126 | + if (FAILED(status)) |
| 127 | + { |
| 128 | + Microsoft::WRL::ComPtr<IDxcBlobUtf8> errors; |
| 129 | + result->GetOutput(DXC_OUT_ERRORS, IID_PPV_ARGS(&errors), nullptr); |
| 130 | + throw std::runtime_error{errors && errors->GetStringLength() > 0 |
| 131 | + ? errors->GetStringPointer() |
| 132 | + : "DXC compilation failed"}; |
| 133 | + } |
| 134 | + |
| 135 | + if (FAILED(result->GetOutput(DXC_OUT_OBJECT, IID_PPV_ARGS(blob), nullptr)) || *blob == nullptr) |
| 136 | + { |
| 137 | + throw std::runtime_error{"DXC did not produce a shader object"}; |
| 138 | + } |
| 139 | + |
| 140 | + return {std::move(parser), std::move(compiler)}; |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +namespace Babylon::Plugins |
| 145 | +{ |
| 146 | + using namespace ShaderCompilerCommon; |
| 147 | + |
| 148 | + ShaderCompiler::ShaderCompiler() |
| 149 | + { |
| 150 | + glslang::InitializeProcess(); |
| 151 | + } |
| 152 | + |
| 153 | + ShaderCompiler::~ShaderCompiler() |
| 154 | + { |
| 155 | + glslang::FinalizeProcess(); |
| 156 | + } |
| 157 | + |
| 158 | + Graphics::BgfxShaderInfo ShaderCompiler::Compile(std::string_view vertexSource, std::string_view fragmentSource) |
| 159 | + { |
| 160 | + glslang::TProgram program; |
| 161 | + |
| 162 | + glslang::TShader vertexShader{EShLangVertex}; |
| 163 | + AddShader(program, vertexShader, ProcessSamplerFlip(ProcessShaderCoordinates(vertexSource))); |
| 164 | + |
| 165 | + glslang::TShader fragmentShader{EShLangFragment}; |
| 166 | + AddShader(program, fragmentShader, ProcessSamplerFlip(fragmentSource)); |
| 167 | + |
| 168 | + glslang::SpvVersion spv{}; |
| 169 | + spv.spv = 0x10000; |
| 170 | + vertexShader.getIntermediate()->setSpv(spv); |
| 171 | + fragmentShader.getIntermediate()->setSpv(spv); |
| 172 | + |
| 173 | + if (!program.link(EShMsgDefault)) |
| 174 | + { |
| 175 | + throw std::runtime_error{program.getInfoLog()}; |
| 176 | + } |
| 177 | + |
| 178 | + ShaderCompilerTraversers::IdGenerator ids{}; |
| 179 | + auto cutScope = ShaderCompilerTraversers::ChangeUniformTypes(program, ids); |
| 180 | + auto utstScope = ShaderCompilerTraversers::MoveNonSamplerUniformsIntoStruct(program, ids); |
| 181 | + std::map<std::string, std::string> vertexAttributeRenaming = {}; |
| 182 | + ShaderCompilerTraversers::AssignLocationsAndNamesToVertexVaryingsD3D(program, ids, vertexAttributeRenaming); |
| 183 | + ShaderCompilerTraversers::SplitSamplersIntoSamplersAndTextures(program, ids); |
| 184 | + ShaderCompilerTraversers::InvertYDerivativeOperands(program); |
| 185 | + |
| 186 | + // clang-format off |
| 187 | + static const spirv_cross::HLSLVertexAttributeRemap attributes[] = { |
| 188 | + {bgfx::Attrib::Position, "POSITION" }, |
| 189 | + {bgfx::Attrib::Normal, "NORMAL" }, |
| 190 | + {bgfx::Attrib::Tangent, "TANGENT" }, |
| 191 | + {bgfx::Attrib::Color0, "COLOR" }, |
| 192 | + {bgfx::Attrib::Indices, "BLENDINDICES"}, |
| 193 | + {bgfx::Attrib::Weight, "BLENDWEIGHT" }, |
| 194 | + {bgfx::Attrib::TexCoord0, "TEXCOORD0" }, |
| 195 | + {bgfx::Attrib::TexCoord1, "TEXCOORD1" }, |
| 196 | + {bgfx::Attrib::TexCoord2, "TEXCOORD2" }, |
| 197 | + {bgfx::Attrib::TexCoord3, "TEXCOORD3" }, |
| 198 | + {bgfx::Attrib::TexCoord4, "TEXCOORD4" }, |
| 199 | + {bgfx::Attrib::TexCoord5, "TEXCOORD5" }, |
| 200 | + {bgfx::Attrib::TexCoord6, "TEXCOORD6" }, |
| 201 | + {bgfx::Attrib::TexCoord7, "TEXCOORD7" }, |
| 202 | + }; |
| 203 | + // clang-format on |
| 204 | + |
| 205 | + Microsoft::WRL::ComPtr<IDxcBlob> vertexBlob; |
| 206 | + auto [vertexParser, vertexCompiler] = CompileShader(program, EShLangVertex, attributes, &vertexBlob); |
| 207 | + ShaderInfo vertexShaderInfo{ |
| 208 | + std::move(vertexParser), |
| 209 | + std::move(vertexCompiler), |
| 210 | + gsl::make_span(static_cast<uint8_t*>(vertexBlob->GetBufferPointer()), vertexBlob->GetBufferSize()), |
| 211 | + std::move(vertexAttributeRenaming)}; |
| 212 | + |
| 213 | + Microsoft::WRL::ComPtr<IDxcBlob> fragmentBlob; |
| 214 | + auto [fragmentParser, fragmentCompiler] = CompileShader(program, EShLangFragment, {}, &fragmentBlob); |
| 215 | + ShaderInfo fragmentShaderInfo{ |
| 216 | + std::move(fragmentParser), |
| 217 | + std::move(fragmentCompiler), |
| 218 | + gsl::make_span(static_cast<uint8_t*>(fragmentBlob->GetBufferPointer()), fragmentBlob->GetBufferSize()), |
| 219 | + {}}; |
| 220 | + |
| 221 | + return CreateBgfxShader(std::move(vertexShaderInfo), std::move(fragmentShaderInfo)); |
| 222 | + } |
| 223 | +} |
0 commit comments