Skip to content

Commit 1f86c5d

Browse files
committed
Early lobe pruning during ShaderGraph construction
1 parent bb3e736 commit 1f86c5d

24 files changed

Lines changed: 723 additions & 41 deletions

resources/Materials/TestSuite/_options.mtlx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,12 @@
8484
Default is false to avoid overhead when not profiling.
8585
-->
8686
<input name="enableTracing" type="boolean" value="true" />
87+
88+
<!-- Enable lobe pruning during ShaderGraph construction.
89+
When a NodeGraph has topological inputs (e.g. mix weights) that are
90+
compile-time constant 0 or 1, skip creating the dead branch nodes.
91+
Default is false.
92+
-->
93+
<input name="enableLobePruning" type="boolean" value="false" />
8794
</nodedef>
8895
</materialx>

source/MaterialXGenHw/HwShaderGenerator.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <MaterialXGenHw/Nodes/HwLightCompoundNode.h>
1111
#include <MaterialXGenHw/Nodes/HwMaterialCompoundNode.h>
1212
#include <MaterialXGenShader/Exception.h>
13+
#include <MaterialXGenShader/NodeGraphTopology.h>
1314
#include <MaterialXGenShader/Nodes/CompoundNode.h>
1415
#include <MaterialXGenShader/GenContext.h>
1516
#include <MaterialXGenShader/Shader.h>
@@ -391,7 +392,9 @@ void HwShaderGenerator::addStageLightingUniforms(GenContext& context, ShaderStag
391392
numActiveLights->setValue(Value::createValue<int>(0));
392393
}
393394
}
394-
ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
395+
ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(
396+
const NodeGraph& nodegraph,
397+
std::unique_ptr<NodeGraphPermutation> permutation) const
395398
{
396399
vector<OutputPtr> outputs = nodegraph.getActiveOutputs();
397400
if (outputs.empty())
@@ -404,15 +407,15 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node
404407
// Use specialized implementations for nodes that output light shaders and materials.
405408
if (outputType == Type::LIGHTSHADER)
406409
{
410+
// HwLightCompoundNode doesn't support permutations (light shaders don't have lobe weights)
407411
return HwLightCompoundNode::create();
408412
}
409413
if (outputType == Type::MATERIAL)
410414
{
411415
return HwMaterialCompoundNode::create();
412416
}
413417

414-
// Use the base implementation for nodes that output other types.
415-
return CompoundNode::create();
418+
return CompoundNode::create(std::move(permutation));
416419
}
417420

418421
void HwShaderGenerator::emitClosureDataArg(const ShaderNode& node, GenContext& /*context*/, ShaderStage& stage) const

source/MaterialXGenHw/HwShaderGenerator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ class MX_GENHW_API HwShaderGenerator : public ShaderGenerator
6464
virtual string getVertexDataPrefix(const VariableBlock& vertexData) const = 0;
6565

6666
/// Create the shader node implementation for a NodeGraph implementation.
67-
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
67+
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
68+
const NodeGraph& nodegraph,
69+
std::unique_ptr<NodeGraphPermutation> permutation) const override;
6870

6971
// Note : the order must match the order defined in libraries/pbrlib/genglsl/lib/mx_closure_type.glsl
7072
// TODO : investigate build time mechanism for ensuring these stay in sync.

source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
MATERIALX_NAMESPACE_BEGIN
1212

1313
HwLightCompoundNode::HwLightCompoundNode() :
14+
CompoundNode(nullptr),
1415
_lightUniforms(HW::LIGHT_DATA, EMPTY_STRING)
1516
{
1617
}

source/MaterialXGenMdl/MdlShaderGenerator.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G
339339
return shader;
340340
}
341341

342-
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
342+
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(
343+
const NodeGraph& nodegraph,
344+
std::unique_ptr<NodeGraphPermutation> permutation) const
343345
{
344346
vector<OutputPtr> outputs = nodegraph.getActiveOutputs();
345347
if (outputs.empty())
@@ -349,13 +351,12 @@ ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const Nod
349351

350352
const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());
351353

352-
ShaderNodeImplPtr impl;
353-
// Use a compound implementation.
354+
// Use a compound implementation with permutation support
354355
if (outputType.isClosure())
355356
{
356-
return ClosureCompoundNodeMdl::create();
357+
return ClosureCompoundNodeMdl::create(std::move(permutation));
357358
}
358-
return CompoundNodeMdl::create();
359+
return CompoundNodeMdl::create(std::move(permutation));
359360
}
360361

361362
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForImplementation(const Implementation& implElement) const

source/MaterialXGenMdl/MdlShaderGenerator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator
7575
ShaderPtr generate(const string& name, ElementPtr element, GenContext& context) const override;
7676

7777
/// Create the shader node implementation for a NodeGraph implementation.
78-
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
78+
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
79+
const NodeGraph& nodegraph,
80+
std::unique_ptr<NodeGraphPermutation> permutation) const override;
7981

8082
/// Create the shader node implementation for an mplementation implementation.
8183
ShaderNodeImplPtr createShaderNodeImplForImplementation(const Implementation& implementation) const override;

source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313

1414
MATERIALX_NAMESPACE_BEGIN
1515

16-
ShaderNodeImplPtr ClosureCompoundNodeMdl::create()
16+
ShaderNodeImplPtr ClosureCompoundNodeMdl::create(std::unique_ptr<NodeGraphPermutation> permutation)
17+
{
18+
return std::make_shared<ClosureCompoundNodeMdl>(std::move(permutation));
19+
}
20+
21+
ClosureCompoundNodeMdl::ClosureCompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation) :
22+
CompoundNodeMdl(std::move(permutation))
1723
{
18-
return std::make_shared<ClosureCompoundNodeMdl>();
1924
}
2025

2126
void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const

source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@ MATERIALX_NAMESPACE_BEGIN
1515
class MX_GENMDL_API ClosureCompoundNodeMdl : public CompoundNodeMdl
1616
{
1717
public:
18-
static ShaderNodeImplPtr create();
18+
/// Create with permutation (may be nullptr).
19+
static ShaderNodeImplPtr create(std::unique_ptr<NodeGraphPermutation> permutation);
1920

2021
void addClassification(ShaderNode& node) const override;
2122
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
2223
void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
24+
25+
explicit ClosureCompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation);
2326
};
2427

2528
MATERIALX_NAMESPACE_END

source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@ MATERIALX_NAMESPACE_BEGIN
1616

1717
const string CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME = "returnStructFieldName";
1818

19-
ShaderNodeImplPtr CompoundNodeMdl::create()
19+
ShaderNodeImplPtr CompoundNodeMdl::create(std::unique_ptr<NodeGraphPermutation> permutation)
20+
{
21+
return std::make_shared<CompoundNodeMdl>(std::move(permutation));
22+
}
23+
24+
CompoundNodeMdl::CompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation) :
25+
CompoundNode(std::move(permutation))
2026
{
21-
return std::make_shared<CompoundNodeMdl>();
2227
}
2328

2429
void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& context)

source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ using GenUserDataStringPtr = std::shared_ptr<GenUserDataString>;
3030
class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
3131
{
3232
public:
33-
static ShaderNodeImplPtr create();
33+
/// Create with permutation (may be nullptr).
34+
static ShaderNodeImplPtr create(std::unique_ptr<NodeGraphPermutation> permutation);
3435

3536
void initialize(const InterfaceElement& element, GenContext& context) override;
3637
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
@@ -39,6 +40,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
3940
bool isReturnStruct() const { return !_returnStruct.empty(); }
4041
bool unrollReturnStructMembers() const { return _unrollReturnStructMembers; }
4142

43+
explicit CompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation);
44+
4245
protected:
4346
void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const;
4447

0 commit comments

Comments
 (0)