Skip to content

Commit

Permalink
Create DataLibrary class
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinbhat committed Oct 5, 2024
1 parent c473375 commit a3164d3
Show file tree
Hide file tree
Showing 15 changed files with 146 additions and 60 deletions.
70 changes: 70 additions & 0 deletions source/MaterialXCore/Datalibrary.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//
// Copyright Contributors to the MaterialX Project
// SPDX-License-Identifier: Apache-2.0
//

#include <MaterialXCore/Document.h>
#include <MaterialXCore/Datalibrary.h>

MATERIALX_NAMESPACE_BEGIN

const DataLibraryPtr standardDataLibrary = DataLibrary::create();

DataLibraryPtr DataLibrary::create()
{
return std::make_shared<DataLibrary>();
}

// use loadDocuments to build this vector
void DataLibrary::loadDataLibrary(vector<DocumentPtr>& librarydocuments)
{
_datalibrary = createDocument();

for (auto library : librarydocuments)
{
for (auto child : library->getChildren())
{
if (child->getCategory().empty())
{
throw Exception("Trying to import child without a category: " + child->getName());
}

const string childName = child->getQualifiedName(child->getName());

// Check for duplicate elements.
ConstElementPtr previous = _datalibrary->getChild(childName);
if (previous)
{
continue;
}

// Create the imported element.
ElementPtr childCopy = _datalibrary->addChildOfCategory(child->getCategory(), childName);
childCopy->copyContentFrom(child);
if (!childCopy->hasFilePrefix() && library->hasFilePrefix())
{
childCopy->setFilePrefix(library->getFilePrefix());
}
if (!childCopy->hasGeomPrefix() && library->hasGeomPrefix())
{
childCopy->setGeomPrefix(library->getGeomPrefix());
}
if (!childCopy->hasColorSpace() && library->hasColorSpace())
{
childCopy->setColorSpace(library->getColorSpace());
}
if (!childCopy->hasNamespace() && library->hasNamespace())
{
childCopy->setNamespace(library->getNamespace());
}
if (!childCopy->hasSourceUri() && library->hasSourceUri())
{
childCopy->setSourceUri(library->getSourceUri());
}
}
}

}


MATERIALX_NAMESPACE_END
41 changes: 41 additions & 0 deletions source/MaterialXCore/Datalibrary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//
// Copyright Contributors to the MaterialX Project
// SPDX-License-Identifier: Apache-2.0
//

#ifndef MATERIALX_DATALIBRARY
#define MATERIALX_DATALIBRARY

/// @file
/// The top-level DataLibrary class

#include <MaterialXCore/Document.h>

MATERIALX_NAMESPACE_BEGIN

class DataLibrary;
using DataLibraryPtr = shared_ptr<DataLibrary>;
using ConstDataLibraryPtr = shared_ptr<const DataLibrary>;

class MX_CORE_API DataLibrary
{
public:
ConstDocumentPtr dataLibrary()
{
return _datalibrary;
}

void loadDataLibrary(vector<DocumentPtr>& documents);

static DataLibraryPtr create();

private:
// Shared node library used across documents.
DocumentPtr _datalibrary;
};

extern MX_CORE_API const DataLibraryPtr standardDataLibrary;

MATERIALX_NAMESPACE_END

#endif
9 changes: 7 additions & 2 deletions source/MaterialXCore/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <MaterialXCore/Document.h>
#include <MaterialXCore/Material.h>
#include <MaterialXCore/Datalibrary.h>

#include <deque>

Expand Down Expand Up @@ -74,8 +75,12 @@ NodeDefPtr Node::getNodeDef(const string& target, bool allowRoughMatch) const
{
return resolveNameReference<NodeDef>(getNodeDefString());
}
vector<NodeDefPtr> nodeDefs = getDocument()->getMatchingNodeDefs(getQualifiedName(getCategory()));
vector<NodeDefPtr> secondary = getDocument()->getMatchingNodeDefs(getCategory());


// If a nodelibrary is not registered, use the document to locate nodedefs
ConstDocumentPtr datalibrarydoc = standardDataLibrary ? standardDataLibrary->dataLibrary() : getDocument();
vector<NodeDefPtr> nodeDefs = datalibrarydoc->getMatchingNodeDefs(getQualifiedName(getCategory()));
vector<NodeDefPtr> secondary = datalibrarydoc->getMatchingNodeDefs(getCategory());
vector<NodeDefPtr> roughMatches;
nodeDefs.insert(nodeDefs.end(), secondary.begin(), secondary.end());
for (NodeDefPtr nodeDef : nodeDefs)
Expand Down
9 changes: 0 additions & 9 deletions source/MaterialXGenShader/ColorManagementSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,8 @@ ColorManagementSystem::ColorManagementSystem()
{
}

void ColorManagementSystem::loadLibrary(DocumentPtr document)
{
_document = document;
}

bool ColorManagementSystem::supportsTransform(const ColorSpaceTransform& transform) const
{
if (!_document)
{
throw ExceptionShaderGenError("No library loaded for color management system");
}
return getNodeDef(transform) != nullptr;
}

Expand Down
7 changes: 1 addition & 6 deletions source/MaterialXGenShader/ColorManagementSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <MaterialXGenShader/TypeDesc.h>

#include <MaterialXCore/Document.h>
#include <MaterialXCore/Datalibrary.h>

MATERIALX_NAMESPACE_BEGIN

Expand Down Expand Up @@ -53,10 +54,6 @@ class MX_GENSHADER_API ColorManagementSystem
/// Return the ColorManagementSystem name
virtual const string& getName() const = 0;

/// Load a library of implementations from the provided document,
/// replacing any previously loaded content.
virtual void loadLibrary(DocumentPtr document);

/// Returns whether this color management system supports a provided transform
bool supportsTransform(const ColorSpaceTransform& transform) const;

Expand All @@ -71,8 +68,6 @@ class MX_GENSHADER_API ColorManagementSystem
/// Returns a nodedef for a given transform
virtual NodeDefPtr getNodeDef(const ColorSpaceTransform& transform) const = 0;

protected:
DocumentPtr _document;
};

MATERIALX_NAMESPACE_END
Expand Down
7 changes: 1 addition & 6 deletions source/MaterialXGenShader/DefaultColorManagementSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,11 @@ const string& DefaultColorManagementSystem::getName() const

NodeDefPtr DefaultColorManagementSystem::getNodeDef(const ColorSpaceTransform& transform) const
{
if (!_document)
{
throw ExceptionShaderGenError("No library loaded for color management system");
}

string sourceSpace = COLOR_SPACE_REMAP.count(transform.sourceSpace) ? COLOR_SPACE_REMAP.at(transform.sourceSpace) : transform.sourceSpace;
string targetSpace = COLOR_SPACE_REMAP.count(transform.targetSpace) ? COLOR_SPACE_REMAP.at(transform.targetSpace) : transform.targetSpace;
string nodeName = sourceSpace + "_to_" + targetSpace;

for (NodeDefPtr nodeDef : _document->getMatchingNodeDefs(nodeName))
for (NodeDefPtr nodeDef : standardDataLibrary->dataLibrary()->getMatchingNodeDefs(nodeName))
{
for (OutputPtr output : nodeDef->getOutputs())
{
Expand Down
3 changes: 2 additions & 1 deletion source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,13 @@ void ShaderGraph::addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geom
const string geomNodeName = "geomprop_" + geomprop.getName();
ShaderNode* node = getNode(geomNodeName);

ConstDocumentPtr datalibrarydoc = standardDataLibrary ? standardDataLibrary->dataLibrary() : _document;
if (!node)
{
// Find the nodedef for the geometric node referenced by the geomprop. Use the type of the
// input here and ignore the type of the geomprop. They are required to have the same type.
string geomNodeDefName = "ND_" + geomprop.getGeomProp() + "_" + input->getType().getName();
NodeDefPtr geomNodeDef = _document->getNodeDef(geomNodeDefName);
NodeDefPtr geomNodeDef = datalibrarydoc->getNodeDef(geomNodeDefName);
if (!geomNodeDef)
{
throw ExceptionShaderGenError("Could not find a nodedef named '" + geomNodeDefName +
Expand Down
14 changes: 2 additions & 12 deletions source/MaterialXGenShader/UnitSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@ UnitSystem::UnitSystem(const string& target) :
{
}

void UnitSystem::loadLibrary(DocumentPtr document)
{
_document = document;
}

void UnitSystem::setUnitConverterRegistry(UnitConverterRegistryPtr registry)
{
_unitRegistry = registry;
Expand All @@ -145,13 +140,8 @@ UnitSystemPtr UnitSystem::create(const string& language)

NodeDefPtr UnitSystem::getNodeDef(const UnitTransform& transform) const
{
if (!_document)
{
throw ExceptionShaderGenError("No library loaded for unit system");
}

const string MULTIPLY_NODE_NAME = "multiply";
for (NodeDefPtr nodeDef : _document->getMatchingNodeDefs(MULTIPLY_NODE_NAME))
for (NodeDefPtr nodeDef : standardDataLibrary->dataLibrary()->getMatchingNodeDefs(MULTIPLY_NODE_NAME))
{
for (OutputPtr output : nodeDef->getOutputs())
{
Expand Down Expand Up @@ -183,7 +173,7 @@ ShaderNodePtr UnitSystem::createNode(ShaderGraph* parent, const UnitTransform& t
}

// Scalar unit conversion
UnitTypeDefPtr scalarTypeDef = _document->getUnitTypeDef(transform.unitType);
UnitTypeDefPtr scalarTypeDef = standardDataLibrary->dataLibrary()->getUnitTypeDef(transform.unitType);
if (!_unitRegistry || !_unitRegistry->getUnitConverter(scalarTypeDef))
{
throw ExceptionTypeError("Unit registry unavaliable or undefined unit converter for: " + transform.unitType);
Expand Down
5 changes: 1 addition & 4 deletions source/MaterialXGenShader/UnitSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <MaterialXCore/Unit.h>

#include <MaterialXCore/Document.h>
#include <MaterialXCore/Datalibrary.h>

MATERIALX_NAMESPACE_BEGIN

Expand Down Expand Up @@ -68,9 +69,6 @@ class MX_GENSHADER_API UnitSystem
/// Returns the currently assigned unit converter registry
virtual UnitConverterRegistryPtr getUnitConverterRegistry() const;

/// assign document with unit implementations replacing any previously loaded content.
virtual void loadLibrary(DocumentPtr document);

/// Returns whether this unit system supports a provided transform
bool supportsTransform(const UnitTransform& transform) const;

Expand All @@ -89,7 +87,6 @@ class MX_GENSHADER_API UnitSystem

protected:
UnitConverterRegistryPtr _unitRegistry;
DocumentPtr _document;
string _target;
};

Expand Down
2 changes: 0 additions & 2 deletions source/MaterialXRender/TextureBaker.inl
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ DocumentPtr TextureBaker<Renderer, ShaderGen>::bakeMaterialToDoc(DocumentPtr doc
genContext.getOptions().targetDistanceUnit = _distanceUnit;

DefaultColorManagementSystemPtr cms = DefaultColorManagementSystem::create(genContext.getShaderGenerator().getTarget());
cms->loadLibrary(doc);
genContext.registerSourceCodeSearchPath(searchPath);
genContext.getShaderGenerator().setColorManagementSystem(cms);

Expand Down Expand Up @@ -644,7 +643,6 @@ void TextureBaker<Renderer, ShaderGen>::setupUnitSystem(DocumentPtr unitDefiniti
UnitConverterRegistryPtr registry = UnitConverterRegistry::create();
registry->addUnitConverter(distanceTypeDef, LinearUnitConverter::create(distanceTypeDef));
registry->addUnitConverter(angleTypeDef, LinearUnitConverter::create(angleTypeDef));
_generator->getUnitSystem()->loadLibrary(unitDefinitions);
_generator->getUnitSystem()->setUnitConverterRegistry(registry);
}

Expand Down
33 changes: 21 additions & 12 deletions source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,30 @@ void testUniqueNames(mx::GenContext& context, const std::string& stage)
REQUIRE(sgNode1->getOutput()->getVariable() == "unique_names_out");
}

void loadDefaultDataLibrary()
{
const mx::FileSearchPath libSearchPath(mx::getDefaultDataSearchPath());

// Load the standard libraries.
mx::FilePath libraryroot(libSearchPath.asString() + "/libraries");
std::vector<mx::DocumentPtr> documentList;
mx::StringVec libdocumentsPaths;
mx::StringVec liberrorLog;
mx::loadDocuments(libraryroot, libSearchPath, {}, {}, documentList, libdocumentsPaths, nullptr, &liberrorLog);

if (liberrorLog.size() == 0)
mx::standardDataLibrary->loadDataLibrary(documentList);
}

// Test ShaderGen performance
void shaderGenPerformanceTest(mx::GenContext& context)
{
mx::DocumentPtr nodeLibrary = mx::createDocument();
const mx::FileSearchPath libSearchPath(mx::getDefaultDataSearchPath());

// Load the standard libraries.
loadLibraries({ "libraries" }, libSearchPath, nodeLibrary);
loadDefaultDataLibrary();

//loadLibraries({ "libraries" }, libSearchPath, nodeLibrary);
context.registerSourceCodeSearchPath(libSearchPath);

// Enable Color Management
Expand All @@ -328,23 +344,19 @@ void shaderGenPerformanceTest(mx::GenContext& context)

REQUIRE(colorManagementSystem);
if (colorManagementSystem)
{
context.getShaderGenerator().setColorManagementSystem(colorManagementSystem);
colorManagementSystem->loadLibrary(nodeLibrary);
}

// Enable Unit System
mx::UnitSystemPtr unitSystem = mx::UnitSystem::create(context.getShaderGenerator().getTarget());
REQUIRE(unitSystem);
if (unitSystem)
{
context.getShaderGenerator().setUnitSystem(unitSystem);
unitSystem->loadLibrary(nodeLibrary);
// Setup Unit converters
unitSystem->setUnitConverterRegistry(mx::UnitConverterRegistry::create());
mx::UnitTypeDefPtr distanceTypeDef = nodeLibrary->getUnitTypeDef("distance");
mx::UnitTypeDefPtr distanceTypeDef = mx::standardDataLibrary->dataLibrary()->getUnitTypeDef("distance");
unitSystem->getUnitConverterRegistry()->addUnitConverter(distanceTypeDef, mx::LinearUnitConverter::create(distanceTypeDef));
mx::UnitTypeDefPtr angleTypeDef = nodeLibrary->getUnitTypeDef("angle");
mx::UnitTypeDefPtr angleTypeDef = mx::standardDataLibrary->dataLibrary()->getUnitTypeDef("angle");
unitSystem->getUnitConverterRegistry()->addUnitConverter(angleTypeDef, mx::LinearUnitConverter::create(angleTypeDef));
context.getOptions().targetDistanceUnit = "meter";
}
Expand All @@ -357,7 +369,6 @@ void shaderGenPerformanceTest(mx::GenContext& context)
std::vector<mx::DocumentPtr> loadedDocuments;
mx::StringVec documentsPaths;
mx::StringVec errorLog;

for (const auto& testRoot : testRootPaths)
{
mx::loadDocuments(testRoot, libSearchPath, {}, {}, loadedDocuments, documentsPaths,
Expand All @@ -372,7 +383,6 @@ void shaderGenPerformanceTest(mx::GenContext& context)
std::shuffle(loadedDocuments.begin(), loadedDocuments.end(), rng);
for (const auto& doc : loadedDocuments)
{
doc->importLibrary(nodeLibrary);
std::vector<mx::TypedElementPtr> elements = mx::findRenderableElements(doc);

REQUIRE(elements.size() > 0);
Expand All @@ -389,6 +399,7 @@ void shaderGenPerformanceTest(mx::GenContext& context)
REQUIRE(shader != nullptr);
REQUIRE(shader->getSourceCode(mx::Stage::PIXEL).length() > 0);
}

}

void ShaderGeneratorTester::checkImplementationUsage(const mx::StringSet& usedImpls,
Expand Down Expand Up @@ -519,7 +530,6 @@ void ShaderGeneratorTester::addColorManagement()
else
{
_shaderGenerator->setColorManagementSystem(_colorManagementSystem);
_colorManagementSystem->loadLibrary(_dependLib);
}
}
}
Expand All @@ -537,7 +547,6 @@ void ShaderGeneratorTester::addUnitSystem()
else
{
_shaderGenerator->setUnitSystem(_unitSystem);
_unitSystem->loadLibrary(_dependLib);
_unitSystem->setUnitConverterRegistry(mx::UnitConverterRegistry::create());
mx::UnitTypeDefPtr distanceTypeDef = _dependLib->getUnitTypeDef("distance");
_unitSystem->getUnitConverterRegistry()->addUnitConverter(distanceTypeDef, mx::LinearUnitConverter::create(distanceTypeDef));
Expand Down
Loading

0 comments on commit a3164d3

Please sign in to comment.