diff --git a/source/MaterialXCore/Element.cpp b/source/MaterialXCore/Element.cpp index 94836245a7..4e9cfbfb80 100644 --- a/source/MaterialXCore/Element.cpp +++ b/source/MaterialXCore/Element.cpp @@ -339,9 +339,9 @@ TreeIterator Element::traverseTree() const return TreeIterator(getSelfNonConst()); } -GraphIterator Element::traverseGraph() const +GraphIterator Element::traverseGraph(bool skipVisitedEdges) const { - return GraphIterator(getSelfNonConst()); + return GraphIterator(getSelfNonConst(), skipVisitedEdges); } Edge Element::getUpstreamEdge(size_t) const diff --git a/source/MaterialXCore/Element.h b/source/MaterialXCore/Element.h index d1abcdfdca..68740f5b8a 100644 --- a/source/MaterialXCore/Element.h +++ b/source/MaterialXCore/Element.h @@ -638,6 +638,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// Traverse the dataflow graph from the given element to each of its /// upstream sources in depth-first order, using pre-order visitation. + /// @param skipVisitedEdges Makes sure that each edge is visited only once. /// @throws ExceptionFoundCycle if a cycle is encountered. /// @return A GraphIterator object. /// @details Example usage with an implicit iterator: @@ -659,7 +660,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// @endcode /// @sa getUpstreamEdge /// @sa getUpstreamElement - GraphIterator traverseGraph() const; + GraphIterator traverseGraph(bool skipVisitedEdges = false) const; /// Return the Edge with the given index that lies directly upstream from /// this element in the dataflow graph. diff --git a/source/MaterialXCore/Interface.cpp b/source/MaterialXCore/Interface.cpp index 24d3ccfe0d..b68a362ab3 100644 --- a/source/MaterialXCore/Interface.cpp +++ b/source/MaterialXCore/Interface.cpp @@ -334,7 +334,7 @@ bool Output::hasUpstreamCycle() const { try { - for (Edge edge : traverseGraph()) { } + for (Edge edge : traverseGraph(true)) { } } catch (ExceptionFoundCycle&) { diff --git a/source/MaterialXCore/Traversal.cpp b/source/MaterialXCore/Traversal.cpp index c78b5825d4..2843a076ea 100644 --- a/source/MaterialXCore/Traversal.cpp +++ b/source/MaterialXCore/Traversal.cpp @@ -113,7 +113,7 @@ GraphIterator& GraphIterator::operator++() // Traverse to the first upstream edge of this element. _stack.emplace_back(_upstreamElem, 0); Edge nextEdge = _upstreamElem->getUpstreamEdge(0); - if (nextEdge && nextEdge.getUpstreamElement()) + if (nextEdge && nextEdge.getUpstreamElement() && (!_skipVisitedEdges || !skipOrMarkAsVisited(nextEdge))) { extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement()); return *this; @@ -140,7 +140,7 @@ GraphIterator& GraphIterator::operator++() if (parentFrame.second + 1 < parentFrame.first->getUpstreamEdgeCount()) { Edge nextEdge = parentFrame.first->getUpstreamEdge(++parentFrame.second); - if (nextEdge && nextEdge.getUpstreamElement()) + if (nextEdge && nextEdge.getUpstreamElement() && (!_skipVisitedEdges || !skipOrMarkAsVisited(nextEdge))) { extendPathUpstream(nextEdge.getUpstreamElement(), nextEdge.getConnectingElement()); return *this; @@ -177,6 +177,12 @@ void GraphIterator::returnPathDownstream(ElementPtr upstreamElem) _connectingElem = ElementPtr(); } +bool GraphIterator::skipOrMarkAsVisited(const Edge& edge) +{ + auto [it, inserted] = _visitedEdges.emplace(edge); + return !inserted; +} + // // InheritanceIterator methods // diff --git a/source/MaterialXCore/Traversal.h b/source/MaterialXCore/Traversal.h index a22d909cbe..6923aedd5b 100644 --- a/source/MaterialXCore/Traversal.h +++ b/source/MaterialXCore/Traversal.h @@ -191,9 +191,10 @@ class MX_CORE_API TreeIterator class MX_CORE_API GraphIterator { public: - explicit GraphIterator(ElementPtr elem) : + explicit GraphIterator(ElementPtr elem, bool skipVisitedEdges = false) : _upstreamElem(elem), _prune(false), + _skipVisitedEdges(skipVisitedEdges), _holdCount(0) { _pathElems.insert(elem); @@ -316,13 +317,16 @@ class MX_CORE_API GraphIterator private: void extendPathUpstream(ElementPtr upstreamElem, ElementPtr connectingElem); void returnPathDownstream(ElementPtr upstreamElem); + bool skipOrMarkAsVisited(const Edge&); private: ElementPtr _upstreamElem; ElementPtr _connectingElem; ElementSet _pathElems; vector _stack; + std::set _visitedEdges; bool _prune; + bool _skipVisitedEdges; size_t _holdCount; }; diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index 0f39b46c32..1dd473c861 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -152,7 +152,7 @@ void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& conte { std::set processedOutputs; - for (Edge edge : root.traverseGraph()) + for (Edge edge : root.traverseGraph(true)) { ElementPtr upstreamElement = edge.getUpstreamElement(); if (!upstreamElement) @@ -900,7 +900,7 @@ void ShaderGraph::optimize() ShaderOutput* upstreamPort = outputSocket->getConnection(); if (upstreamPort && upstreamPort->getNode() != this) { - for (ShaderGraphEdge edge : ShaderGraph::traverseUpstream(upstreamPort)) + for (ShaderGraphEdge edge : traverseUpstream(upstreamPort)) { ShaderNode* node = edge.upstream->getNode(); if (usedNodesSet.count(node) == 0) @@ -1190,7 +1190,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++() ShaderInput* input = _upstream->getNode()->getInput(0); ShaderOutput* output = input->getConnection(); - if (output && !output->getNode()->isAGraph()) + if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input })) { extendPathUpstream(output, input); return *this; @@ -1218,7 +1218,7 @@ ShaderGraphEdgeIterator& ShaderGraphEdgeIterator::operator++() ShaderInput* input = parentFrame.first->getNode()->getInput(++parentFrame.second); ShaderOutput* output = input->getConnection(); - if (output && !output->getNode()->isAGraph()) + if (output && !output->getNode()->isAGraph() && !skipOrMarkAsVisited({ output, input })) { extendPathUpstream(output, input); return *this; @@ -1259,4 +1259,10 @@ void ShaderGraphEdgeIterator::returnPathDownstream(ShaderOutput* upstream) _downstream = nullptr; } +bool ShaderGraphEdgeIterator::skipOrMarkAsVisited(ShaderGraphEdge edge) +{ + auto [it, inserted] = _visitedEdges.emplace(edge); + return !inserted; +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenShader/ShaderGraph.h b/source/MaterialXGenShader/ShaderGraph.h index 6320d04296..8437fd407b 100644 --- a/source/MaterialXGenShader/ShaderGraph.h +++ b/source/MaterialXGenShader/ShaderGraph.h @@ -112,7 +112,8 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode /// Sort the nodes in topological order. void topologicalSort(); - /// Return an iterator for traversal upstream from the given output + /// Return an iterator for traversal upstream from the given output. + /// Edges are visited only once. static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output); /// Return the map of unique identifiers used in the scope of this graph. @@ -209,6 +210,22 @@ class MX_GENSHADER_API ShaderGraphEdge downstream(down) { } + + bool operator==(const ShaderGraphEdge& rhs) const + { + return upstream == rhs.upstream && downstream == rhs.downstream; + } + + bool operator!=(const ShaderGraphEdge& rhs) const + { + return !(*this == rhs); + } + + bool operator<(const ShaderGraphEdge& rhs) const + { + return std::tie(upstream, downstream) < std::tie(rhs.upstream, rhs.downstream); + } + ShaderOutput* upstream; ShaderInput* downstream; }; @@ -254,12 +271,14 @@ class MX_GENSHADER_API ShaderGraphEdgeIterator private: void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream); void returnPathDownstream(ShaderOutput* upstream); + bool skipOrMarkAsVisited(ShaderGraphEdge); ShaderOutput* _upstream; ShaderInput* _downstream; using StackFrame = std::pair; std::vector _stack; std::set _path; + std::set _visitedEdges; }; MATERIALX_NAMESPACE_END