From 93a7bfec18f3e69390bf3332ffa48569e9d48eb8 Mon Sep 17 00:00:00 2001 From: Krzysztof Jakubowski Date: Fri, 20 Sep 2024 14:59:45 +0200 Subject: [PATCH] Fix performance issues with graph edge iteration in ShaderGraph In complex materials graph and shader graph edge iteration can be extremely slow, because some edges may be visited many times unnecessarily. This is especially noticable in two functions: ShaderGraph::addUpstreamDependencies and ShaderGraph::optimize() . GraphIterator and ShaderGraphEdgeIterator classes iterate over DAGs without marking nodes as visited, which may lead to exponential traversal time for some DAGs: https://stackoverflow.com/a/69326676 This patch adds an option to skip visited edges in GraphIterator and modifies ShaderGraphEdgeIterator so that each edge is visited only once. --- source/MaterialXCore/Element.cpp | 4 ++-- source/MaterialXCore/Element.h | 3 ++- source/MaterialXCore/Interface.cpp | 2 +- source/MaterialXCore/Traversal.cpp | 10 ++++++++-- source/MaterialXCore/Traversal.h | 6 +++++- source/MaterialXGenShader/ShaderGraph.cpp | 14 ++++++++++---- source/MaterialXGenShader/ShaderGraph.h | 21 ++++++++++++++++++++- 7 files changed, 48 insertions(+), 12 deletions(-) diff --git a/source/MaterialXCore/Element.cpp b/source/MaterialXCore/Element.cpp index 94836245a7..4e9cfbfb80 100644 --- a/source/MaterialXCore/Element.cpp +++ b/source/MaterialXCore/Element.cpp @@ -339,9 +339,9 @@ TreeIterator Element::traverseTree() const return TreeIterator(getSelfNonConst()); } -GraphIterator Element::traverseGraph() const +GraphIterator Element::traverseGraph(bool skipVisitedEdges) const { - return GraphIterator(getSelfNonConst()); + return GraphIterator(getSelfNonConst(), skipVisitedEdges); } Edge Element::getUpstreamEdge(size_t) const diff --git a/source/MaterialXCore/Element.h b/source/MaterialXCore/Element.h index d1abcdfdca..68740f5b8a 100644 --- a/source/MaterialXCore/Element.h +++ b/source/MaterialXCore/Element.h @@ -638,6 +638,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// Traverse the dataflow graph from the given element to each of its /// upstream sources in depth-first order, using pre-order visitation. + /// @param skipVisitedEdges Makes sure that each edge is visited only once. /// @throws ExceptionFoundCycle if a cycle is encountered. /// @return A GraphIterator object. /// @details Example usage with an implicit iterator: @@ -659,7 +660,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// @endcode /// @sa getUpstreamEdge /// @sa getUpstreamElement - GraphIterator traverseGraph() const; + GraphIterator traverseGraph(bool skipVisitedEdges = false) const; /// Return the Edge with the given index that lies directly upstream from /// this element in the dataflow graph. diff --git a/source/MaterialXCore/Interface.cpp b/source/MaterialXCore/Interface.cpp index 24d3ccfe0d..b68a362ab3 100644 --- a/source/MaterialXCore/Interface.cpp +++ b/source/MaterialXCore/Interface.cpp @@ -334,7 +334,7 @@ bool Output::hasUpstreamCycle() const { try { - for (Edge edge : traverseGraph()) { } + for (Edge edge : traverseGraph(true)) { } } catch (ExceptionFoundCycle&) { diff --git a/source/MaterialXCore/Traversal.cpp b/source/MaterialXCore/Traversal.cpp index c78b5825d4..2843a076ea 100644 --- a/source/MaterialXCore/Traversal.cpp +++ b/source/MaterialXCore/Traversal.cpp @@ -113,7 +113,7 @@ GraphIterator& GraphIterator::operator++() // Traverse to the first upstream edge of this element. _stack.emplace_back(_upstreamElem, 0); Edge nextEdge = _upstreamElem->getUpstreamEdge(0); - if (nextEdge && nextEdge.getUpstreamElement()) + if (nextEdge && nextEdge.getUpstreamElement() && (!_skipVisitedEdges || !skipOrMarkAsVisited(nextEdge))) { extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement()); return *this; @@ -140,7 +140,7 @@ GraphIterator& GraphIterator::operator++() if (parentFrame.second + 1 < parentFrame.first->getUpstreamEdgeCount()) { Edge nextEdge = parentFrame.first->getUpstreamEdge(++parentFrame.second); - if (nextEdge && nextEdge.getUpstreamElement()) + if (nextEdge && nextEdge.getUpstreamElement() && (!_skipVisitedEdges || !skipOrMarkAsVisited(nextEdge))) { extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement()); return *this; @@ -177,6 +177,12 @@ void GraphIterator::returnPathDownstream(ElementPtr upstreamElem) _connectingElem = ElementPtr(); } +bool GraphIterator::skipOrMarkAsVisited(const Edge& edge) +{ + auto [it, inserted] = _visitedEdges.emplace(edge); + return !inserted; +} + // // InheritanceIterator methods // diff --git a/source/MaterialXCore/Traversal.h b/source/MaterialXCore/Traversal.h index a22d909cbe..6923aedd5b 100644 --- a/source/MaterialXCore/Traversal.h +++ b/source/MaterialXCore/Traversal.h @@ -191,9 +191,10 @@ class MX_CORE_API TreeIterator class MX_CORE_API GraphIterator { public: - explicit GraphIterator(ElementPtr elem) : + explicit GraphIterator(ElementPtr elem, bool skipVisitedEdges = false) : _upstreamElem(elem), _prune(false), + _skipVisitedEdges(skipVisitedEdges), _holdCount(0) { _pathElems.insert(elem); @@ -316,13 +317,16 @@ class MX_CORE_API GraphIterator private: void extendPathUpstream(ElementPtr upstreamElem, ElementPtr connectingElem); void returnPathDownstream(ElementPtr upstreamElem); + bool skipOrMarkAsVisited(const Edge&); private: ElementPtr _upstreamElem; ElementPtr _connectingElem; ElementSet _pathElems; vector _stack; + std::set _visitedEdges; bool _prune; + bool _skipVisitedEdges; size_t _holdCount; }; diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index 0f39b46c32..1dd473c861 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -152,7 +152,7 @@ void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& conte { std::set processedOutputs; - for (Edge edge : root.traverseGraph()) + for (Edge edge : root.traverseGraph(true)) { ElementPtr upstreamElement = edge.getUpstreamElement(); if (!upstreamElement) @@ -900,7 +900,7 @@ void ShaderGraph::optimize() ShaderOutput* upstreamPort = outputSocket->getConnection(); if (upstreamPort && upstreamPort->getNode() != this) { - for (ShaderGraphEdge edge : ShaderGraph::traverseUpstream(upstreamPort)) + for (ShaderGraphEdge edge : traverseUpstream(upstreamPort)) { ShaderNode* node = edge.upstream->getNode(); if (usedNodesSet.count(node) == 0) @@ -1190,7 +1190,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++() ShaderInput* input = _upstream->getNode()->getInput(0); ShaderOutput* output = input->getConnection(); - if (output && !output->getNode()->isAGraph()) + if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input })) { extendPathUpstream(output, input); return *this; @@ -1218,7 +1218,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++() ShaderInput* input = parentFrame.first->getNode()->getInput(++parentFrame.second); ShaderOutput* output = input->getConnection(); - if (output && !output->getNode()->isAGraph()) + if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input })) { extendPathUpstream(output, input); return *this; @@ -1259,4 +1259,10 @@ void ShaderGraphEdgeIterator::returnPathDownstream(ShaderOutput* upstream) _downstream = nullptr; } +bool ShaderGraphEdgeIterator::skipOrMarkAsVisited(ShaderGraphEdge edge) +{ + auto [it, inserted] = _visitedEdges.emplace(edge); + return !inserted; +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenShader/ShaderGraph.h b/source/MaterialXGenShader/ShaderGraph.h index 6320d04296..8437fd407b 100644 --- a/source/MaterialXGenShader/ShaderGraph.h +++ b/source/MaterialXGenShader/ShaderGraph.h @@ -112,7 +112,8 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode /// Sort the nodes in topological order. void topologicalSort(); - /// Return an iterator for traversal upstream from the given output + /// Return an iterator for traversal upstream from the given output. + /// Edges are visited only once. static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output); /// Return the map of unique identifiers used in the scope of this graph. @@ -209,6 +210,22 @@ class MX_GENSHADER_API ShaderGraphEdge downstream(down) { } + + bool operator==(const ShaderGraphEdge& rhs) const + { + return upstream == rhs.upstream && downstream == rhs.downstream; + } + + bool operator!=(const ShaderGraphEdge& rhs) const + { + return !(*this == rhs); + } + + bool operator<(const ShaderGraphEdge& rhs) const + { + return std::tie(upstream, downstream) < std::tie(rhs.upstream, rhs.downstream); + } + ShaderOutput* upstream; ShaderInput* downstream; }; @@ -254,12 +271,14 @@ class MX_GENSHADER_API ShaderGraphEdgeIterator private: void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream); void returnPathDownstream(ShaderOutput* upstream); + bool skipOrMarkAsVisited(ShaderGraphEdge); ShaderOutput* _upstream; ShaderInput* _downstream; using StackFrame = std::pair; std::vector _stack; std::set _path; + std::set _visitedEdges; }; MATERIALX_NAMESPACE_END