From 604499a1e37eaf21150970815aecbdf179664d9c Mon Sep 17 00:00:00 2001 From: Ashwin Bhat Date: Wed, 9 Oct 2024 11:24:37 -0700 Subject: [PATCH] Update getMatching* logic and build fixes Update getMatchingNodeDefs and getMatchingImplementations such that node from datalibrary and local document are searched. Build fixes and code cleanup --- source/MaterialXCore/Document.cpp | 40 ++++++++--------------- source/MaterialXCore/Document.h | 4 +-- source/MaterialXCore/Element.cpp | 5 +-- source/MaterialXCore/Element.h | 4 +-- source/MaterialXCore/Interface.cpp | 9 ++--- source/MaterialXCore/Node.cpp | 1 + source/MaterialXGenShader/ShaderGraph.cpp | 5 --- 7 files changed, 25 insertions(+), 43 deletions(-) diff --git a/source/MaterialXCore/Document.cpp b/source/MaterialXCore/Document.cpp index 267597e266..d2047615c9 100644 --- a/source/MaterialXCore/Document.cpp +++ b/source/MaterialXCore/Document.cpp @@ -357,13 +357,10 @@ vector Document::getMaterialOutputs() const vector Document::getMatchingNodeDefs(const string& nodeName) const { - // Return all nodedefs from datalibrary if available - if (_dataLibrary) - { - auto datalibrarynodes = _dataLibrary->getMatchingNodeDefs(nodeName); - if (!datalibrarynodes.empty()) - return datalibrarynodes; - } + // Gather all nodedefs from datalibrary if available + vector matchingNodeDefs = hasDataLibrary() ? + getRegisteredDataLibrary()->getMatchingNodeDefs(nodeName) : + vector(); // Refresh the cache. _cache->refresh(); @@ -371,37 +368,28 @@ vector Document::getMatchingNodeDefs(const string& nodeName) const // Return all nodedefs matching the given node name. if (_cache->nodeDefMap.count(nodeName)) { - return _cache->nodeDefMap.at(nodeName); - } - else - { - return vector(); + matchingNodeDefs.insert(matchingNodeDefs.end(), _cache->nodeDefMap.at(nodeName).begin(), _cache->nodeDefMap.at(nodeName).end()); } + + return matchingNodeDefs; } vector Document::getMatchingImplementations(const string& nodeDef) const { - - // Return all implementations from datalibrary if available - if (_dataLibrary) - { - auto datalibrarynodes = _dataLibrary->getMatchingImplementations(nodeDef); - if (!datalibrarynodes.empty()) - return datalibrarynodes; - } - + // Gather all implementations from datalibrary if available + vector matchingImplementations = hasDataLibrary() ? + getRegisteredDataLibrary()->getMatchingImplementations(nodeDef) : + vector(); // Refresh the cache. _cache->refresh(); // Return all implementations matching the given nodedef string. if (_cache->implementationMap.count(nodeDef)) { - return _cache->implementationMap.at(nodeDef); - } - else - { - return vector(); + matchingImplementations.insert(matchingImplementations.end(), _cache->implementationMap.at(nodeDef).begin(), _cache->implementationMap.at(nodeDef).end()); } + + return matchingImplementations; } bool Document::validate(string* message) const diff --git a/source/MaterialXCore/Document.h b/source/MaterialXCore/Document.h index 05e4f97eb3..eaaf6a5971 100644 --- a/source/MaterialXCore/Document.h +++ b/source/MaterialXCore/Document.h @@ -546,7 +546,7 @@ class MX_CORE_API Document : public GraphElement /// Return the UnitDef, if any, with the given name. UnitDefPtr getUnitDef(const string& name) const { - return getChildOfType(name); + return hasDataLibrary() ? getChildOfType(getRegisteredDataLibrary(), name) : getChildOfType(name); } /// Return a vector of all Member elements in the TypeDef. @@ -577,7 +577,7 @@ class MX_CORE_API Document : public GraphElement /// Return the UnitTypeDef, if any, with the given name. UnitTypeDefPtr getUnitTypeDef(const string& name) const { - return getChildOfType(name); + return hasDataLibrary() ? getChildOfType(getRegisteredDataLibrary(), name) : getChildOfType(name); } /// Return a vector of all UnitTypeDef elements in the document. diff --git a/source/MaterialXCore/Element.cpp b/source/MaterialXCore/Element.cpp index eb67c453e1..94836245a7 100644 --- a/source/MaterialXCore/Element.cpp +++ b/source/MaterialXCore/Element.cpp @@ -564,10 +564,7 @@ bool ValueElement::validate(string* message) const const string& unittype = getUnitType(); if (!unittype.empty()) { - - unitTypeDef = getDocument()->hasDataLibrary() ? - getDocument()->getRegisteredDataLibrary()->getUnitTypeDef(unittype) : - getDocument()->getUnitTypeDef(unittype); + unitTypeDef = getDocument()->getUnitTypeDef(unittype); validateRequire(unitTypeDef != nullptr, res, message, "Unit type definition does not exist in document"); } } diff --git a/source/MaterialXCore/Element.h b/source/MaterialXCore/Element.h index 373ac16c6b..f9b86f1d7e 100644 --- a/source/MaterialXCore/Element.h +++ b/source/MaterialXCore/Element.h @@ -444,7 +444,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// Return the child element from data library , if any, with the given name and subclass. /// If a child with the given name exists, but belongs to a different /// subclass, then an empty shared pointer is returned. - template shared_ptr getChildOfType(ConstDocumentPtr datalibrary, const string& name) const + template shared_ptr getChildOfType(ConstElementPtr datalibrary, const string& name) const { ElementPtr child = datalibrary->getChild(name); if (!child) @@ -483,7 +483,7 @@ class MX_CORE_API Element : public std::enable_shared_from_this /// Return a combined vector of all child elements including the Data Library that are instances of the given /// subclass, optionally filtered by the given category string. The returned /// vector maintains the order in which children were added. - template vector> getChildrenOfType(ConstDocumentPtr datalibrary, const string& category = EMPTY_STRING) const + template vector> getChildrenOfType(ConstElementPtr datalibrary, const string& category = EMPTY_STRING) const { vector> libraryChildren = datalibrary->getChildrenOfType(category); vector> children = getChildrenOfType(category); diff --git a/source/MaterialXCore/Interface.cpp b/source/MaterialXCore/Interface.cpp index dfb84efc04..8f3f1b3acd 100644 --- a/source/MaterialXCore/Interface.cpp +++ b/source/MaterialXCore/Interface.cpp @@ -282,10 +282,11 @@ GeomPropDefPtr Input::getDefaultGeomProp() const const string& defaultGeomProp = getAttribute(DEFAULT_GEOM_PROP_ATTRIBUTE); if (!defaultGeomProp.empty()) { - ConstDocumentPtr doc = getDocument()->hasDataLibrary() ? - getDocument()->getRegisteredDataLibrary() : - getDocument(); - return doc->getChildOfType(defaultGeomProp); + ConstDocumentPtr doc = getDocument(); + if (doc->hasDataLibrary()) + return doc->getChildOfType(doc->getRegisteredDataLibrary(),defaultGeomProp); + else + return doc->getChildOfType(defaultGeomProp); } return nullptr; } diff --git a/source/MaterialXCore/Node.cpp b/source/MaterialXCore/Node.cpp index 7086cdebac..c63e7c0f6c 100644 --- a/source/MaterialXCore/Node.cpp +++ b/source/MaterialXCore/Node.cpp @@ -70,6 +70,7 @@ string Node::getConnectedNodeName(const string& inputName) const NodeDefPtr Node::getNodeDef(const string& target, bool allowRoughMatch) const { + // Collect document nodes vector nodeDefs = getDocument()->getMatchingNodeDefs(getQualifiedName(getCategory())); vector secondary = getDocument()->getMatchingNodeDefs(getCategory()); nodeDefs.insert(nodeDefs.end(), secondary.begin(), secondary.end()); diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index 371ef0d544..3b0c280e21 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -202,16 +202,11 @@ void ShaderGraph::addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geom // input here and ignore the type of the geomprop. They are required to have the same type. string geomNodeDefName = "ND_" + geomprop.getGeomProp() + "_" + input->getType().getName(); NodeDefPtr geomNodeDef = _document->getNodeDef(geomNodeDefName); - if (!geomNodeDef && _document->hasDataLibrary()) - { - geomNodeDef = _document->getRegisteredDataLibrary()->getNodeDef(geomNodeDefName); if (!geomNodeDef) { - throw ExceptionShaderGenError("Could not find a nodedef named '" + geomNodeDefName + "' for defaultgeomprop on input '" + input->getFullName() + "'"); } - } ShaderNodePtr geomNode = ShaderNode::create(this, geomNodeName, *geomNodeDef, context); addNode(geomNode);