diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp index 71db0d7cf4..1bc65d5cc4 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp +++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -157,6 +158,24 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G emitLineEnd(stage, true); } + // Emit custom node imports for nodes in the graph + for (ShaderNode* node : graph.getNodes()) + { + const ShaderNodeImpl& impl = node->getImplementation(); + const CustomCodeNodeMdl* customNode = dynamic_cast(&impl); + if (customNode) + { + const string& importName = customNode->getQualifiedModuleName(); + if (!importName.empty()) + { + emitString("import ", stage); + emitString(importName, stage); + emitString("::*", stage); + emitLineEnd(stage, true); + } + } + } + // Add global constants and type definitions emitTypeDefinitions(context, stage); @@ -353,14 +372,31 @@ ShaderNodeImplPtr MdlShaderGenerator::getImplementation(const NodeDef& nodedef, impl = _implFactory.create(name); if (!impl) { - // Fall back to source code implementation. - if (outputType.isClosure()) + // When `file` and `function` are provided we consider this node a user node + const string file = implElement->getTypedAttribute("file"); + const string function = implElement->getTypedAttribute("function"); + // Or, if `sourcecode` is provided we consider this node a user node with inline implementation + // inline implementations are not supposed to have replacement markers + const string sourcecode = implElement->getTypedAttribute("sourcecode"); + if ((!file.empty() && !function.empty()) || (!sourcecode.empty() && sourcecode.find("{{") == string::npos)) + { + impl = CustomCodeNodeMdl::create(); + } + else if (file.empty() && sourcecode.empty()) { - impl = ClosureSourceCodeNodeMdl::create(); + throw ExceptionShaderGenError("No valid MDL implementation found for '" + name + "'"); } else { - impl = SourceCodeNodeMdl::create(); + // Fall back to source code implementation. + if (outputType.isClosure()) + { + impl = ClosureSourceCodeNodeMdl::create(); + } + else + { + impl = SourceCodeNodeMdl::create(); + } } } } @@ -386,6 +422,7 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex return ShaderGenerator::getUpstreamResult(input, context); } + const MdlSyntax& mdlSyntax = static_cast(getSyntax()); string variable; const ShaderNode* upstreamNode = upstreamOutput->getNode(); if (!upstreamNode->isAGraph() && upstreamNode->numOutputs() > 1) @@ -397,7 +434,18 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex } else { - variable = upstreamNode->getName() + "_result.mxp_" + upstreamOutput->getName(); + const string& fieldName = upstreamOutput->getName(); + const CustomCodeNodeMdl* upstreamCustomNodeMdl = dynamic_cast(&upstreamNode->getImplementation()); + if (upstreamCustomNodeMdl) + { + // Prefix the port name depending on the CustomCodeNode + variable = upstreamNode->getName() + "_result." + upstreamCustomNodeMdl->modifyPortName(fieldName, mdlSyntax); + } + else + { + // Existing implementations and none user defined structs will keep the prefix always to not break existing content + variable = upstreamNode->getName() + "_result." + mdlSyntax.modifyPortName(upstreamOutput->getName()); + } } } else diff --git a/source/MaterialXGenMdl/MdlSyntax.cpp b/source/MaterialXGenMdl/MdlSyntax.cpp index 0244a8352f..c0859f5ae2 100644 --- a/source/MaterialXGenMdl/MdlSyntax.cpp +++ b/source/MaterialXGenMdl/MdlSyntax.cpp @@ -29,6 +29,8 @@ TYPEDESC_REGISTER_TYPE(MDL_SCATTER_MODE, "scatter_mode") namespace { +const string MARKER_MDL_VERSION_SUFFIX = "MDL_VERSION_SUFFIX"; + class MdlFilenameTypeSyntax : public ScalarTypeSyntax { public: @@ -195,6 +197,8 @@ const StringVec MdlSyntax::FILTERTYPE_MEMBERS = { "box", "gaussian" }; const StringVec MdlSyntax::DISTRIBUTIONTYPE_MEMBERS = { "ggx" }; const StringVec MdlSyntax::SCATTER_MODE_MEMBERS = { "R", "T", "RT" }; +const string MdlSyntax::PORT_NAME_PREFIX = "mxp_"; + // // MdlSyntax methods // @@ -202,22 +206,40 @@ const StringVec MdlSyntax::SCATTER_MODE_MEMBERS = { "R", "T", "RT" }; MdlSyntax::MdlSyntax() { // Add in all reserved words and keywords in MDL + // Formatted as in the MDL Specification 1.9.2 for easy comparing registerReservedWords( - { // Reserved words - "annotation", "bool", "bool2", "bool3", "bool4", "break", "bsdf", "bsdf_measurement", "case", "cast", "color", "const", - "continue", "default", "do", "double", "double2", "double2x2", "double2x3", "double3", "double3x2", "double3x3", "double3x4", - "double4", "double4x3", "double4x4", "double4x2", "double2x4", "edf", "else", "enum", "export", "false", "float", "float2", - "float2x2", "float2x3", "float3", "float3x2", "float3x3", "float3x4", "float4", "float4x3", "float4x4", "float4x2", "float2x4", - "for", "hair_bsdf", "if", "import", "in", "int", "int2", "int3", "int4", "intensity_mode", "intensity_power", "intensity_radiant_exitance", - "let", "light_profile", "material", "material_emission", "material_geometry", "material_surface", "material_volume", "mdl", "module", - "package", "return", "string", "struct", "switch", "texture_2d", "texture_3d", "texture_cube", "texture_ptex", "true", "typedef", "uniform", - "using", "varying", "vdf", "while", - // Reserved for future use - "auto", "catch", "char", "class", "const_cast", "delete", "dynamic_cast", "explicit", "extern", "external", "foreach", "friend", "goto", - "graph", "half", "half2", "half2x2", "half2x3", "half3", "half3x2", "half3x3", "half3x4", "half4", "half4x3", "half4x4", "half4x2", "half2x4", - "inline", "inout", "lambda", "long", "mutable", "namespace", "native", "new", "operator", "out", "phenomenon", "private", "protected", "public", - "reinterpret_cast", "sampler", "shader", "short", "signed", "sizeof", "static", "static_cast", "technique", "template", "this", "throw", "try", - "typeid", "typename", "union", "unsigned", "virtual", "void", "volatile", "wchar_t" }); + { // Reserved words + "annotation", "double2", "float", "in", "operator", + "auto", "double2x2", "float2", "int", "package", + "bool", "double2x3", "float2x2", "int2", "return", + "bool2", "double3", "float2x3", "int3", "string", + "bool3", "double3x2", "float3", "int4", "struct", + "bool4", "double3x3", "float3x2", "intensity_mode", "struct_category", + "break", "double3x4", "float3x3", "intensity_power", "switch", + "bsdf", "double4", "float3x4", "intensity_radiant_exitance", "texture_2d", + "bsdf_measurement", "double4x3", "float4", "let", "texture_3d", + "case", "double4x4", "float4x3", "light_profile", "texture_cube", + "cast", "double4x2", "float4x4", "material", "texture_ptex", + "color", "double2x4", "float4x2", "material_emission", "true", + "const", "edf", "float2x4", "material_geometry", "typedef", + "continue", "else", "for", "material_surface", "uniform", + "declarative", "enum", "hair_bsdf", "material_volume", "using", + "default", "export", "if", "mdl", "varying", + "do", "false", "import", "module", "vdf", + "double", "while", + + // Reserved for future use + "catch", "friend", "half3x4", "mutable", "sampler", "throw", + "char", "goto", "half4", "namespace", "shader", "try", + "class", "graph", "half4x3", "native", "short", "typeid", + "const_cast", "half", "half4x4", "new", "signed", "typename", + "delete", "half2", "half4x2", "out", "sizeof", "union", + "dynamic_cast", "half2x2", "half2x4", "phenomenon", "static", "unsigned", + "explicit", "half2x3", "inline", "private", "static_cast", "virtual", + "extern", "half3", "inout", "protected", "technique", "void", + "external", "half3x2", "lambda", "public", "template", "volatile", + "foreach", "half3x3", "long", "reinterpret_cast", "this", "wchar_t", + }); // Register restricted tokens in MDL StringMap tokens; @@ -533,4 +555,41 @@ void MdlSyntax::makeValidName(string& name) const } } +string MdlSyntax::modifyPortName(const string& word) const +{ + return PORT_NAME_PREFIX + word; +} + +string MdlSyntax::replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) const +{ + // An inline function call + // Replace tokens of the format "{{}}" + static const string prefix("{{"); + static const string postfix("}}"); + + size_t pos = 0; + size_t i = soureCode.find_first_of(prefix); + StringVec code; + while (i != string::npos) + { + code.push_back(soureCode.substr(pos, i - pos)); + size_t j = soureCode.find_first_of(postfix, i + 2); + if (j == string::npos) + { + throw ExceptionShaderGenError("Malformed inline expression in implementation for node " + nodeName); + } + const string marker = soureCode.substr(i + 2, j - i - 2); + code.push_back(lambda(marker)); + pos = j + 2; + i = soureCode.find_first_of(prefix, pos); + } + code.push_back(soureCode.substr(pos)); + return joinStrings(code, EMPTY_STRING); +} + +const string MdlSyntax::getMdlVersionSuffixMarker() const +{ + return MARKER_MDL_VERSION_SUFFIX; +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/MdlSyntax.h b/source/MaterialXGenMdl/MdlSyntax.h index e282cefeca..48fd6214c0 100644 --- a/source/MaterialXGenMdl/MdlSyntax.h +++ b/source/MaterialXGenMdl/MdlSyntax.h @@ -53,6 +53,7 @@ class MX_GENMDL_API MdlSyntax : public Syntax static const StringVec FILTERTYPE_MEMBERS; static const StringVec DISTRIBUTIONTYPE_MEMBERS; static const StringVec SCATTER_MODE_MEMBERS; + static const string PORT_NAME_PREFIX; // Applied to input and output names to avoid collisions with reserved words in MDL /// Get an type description for an enumeration based on member value TypeDesc getEnumeratedType(const string& value) const; @@ -63,6 +64,16 @@ class MX_GENMDL_API MdlSyntax : public Syntax /// Modify the given name string to remove any invalid characters or tokens. void makeValidName(string& name) const override; + + /// To avoid collisions with reserved names in MDL, input and output names are prefixed. + string modifyPortName(const string& word) const; + + /// Replaces all markers in a source code string indicated by {{...}}. + /// The replacement is defined by a callback function. + string replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) const; + + /// Get the MDL language versing marker: {{MDL_VERSION_SUFFIX}}. + const string getMdlVersionSuffixMarker() const; }; namespace Type diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp index 0a2004cd71..59b488c875 100644 --- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp @@ -4,6 +4,7 @@ // #include +#include #include #include @@ -29,6 +30,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit functions for all child nodes shadergen.emitFunctionDefinitions(*_rootGraph, context, stage); @@ -146,7 +148,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { const string result = shadergen.getUpstreamResult(output, context); - shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage); + shadergen.emitLine(resultVariableName + mdlSyntax.modifyPortName(output->getName()) + " = " + result, stage); } shadergen.emitLine("return " + resultVariableName, stage); } diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp index 80c449e507..2dab721e24 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -46,6 +47,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) || _rootGraph->hasClassification(ShaderNode::Classification::SHADER)); @@ -83,7 +85,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { const string result = shadergen.getUpstreamResult(output, context); - shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage); + shadergen.emitLine(resultVariableName + "." + syntax.modifyPortName(output->getName()) + " = " + result, stage); } shadergen.emitLine("return " + resultVariableName, stage); } @@ -180,7 +182,7 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& context, ShaderStage& stage) const { const ShaderGenerator& shadergen = context.getShaderGenerator(); - const Syntax& syntax = shadergen.getSyntax(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); if (!_returnStruct.empty()) { @@ -208,7 +210,7 @@ void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& conte shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { - shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage); + shadergen.emitLine(syntax.getTypeName(output->getType()) + " " + syntax.modifyPortName(output->getName()), stage); } shadergen.emitScopeEnd(stage, true); shadergen.emitLineBreak(stage); diff --git a/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp new file mode 100644 index 0000000000..7215b90256 --- /dev/null +++ b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp @@ -0,0 +1,278 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include +#include +#include +#include +#include + +MATERIALX_NAMESPACE_BEGIN + +ShaderNodeImplPtr CustomCodeNodeMdl::create() +{ + return std::make_shared(); +} + +const string& CustomCodeNodeMdl::getQualifiedModuleName() const +{ + return _qualifiedModuleName; +} + +string CustomCodeNodeMdl::modifyPortName(const string& name, const MdlSyntax& syntax) const +{ + if (_useExternalSourceCode) + { + const StringSet& reservedWords = syntax.getReservedWords(); + if (reservedWords.find(name) == reservedWords.end()) + { + // Use existing MDL parameter names if they don't collide with a reserved word. + // This allows us to reference MDL existing functions without changing the MDL source code. + return name; + } + } + return syntax.modifyPortName(name); +} + +void CustomCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& context) +{ + SourceCodeNodeMdl::initialize(element, context); + if (_inlined) + { + _useExternalSourceCode = false; + initializeForInlineSourceCode(element, context); + } + else + { + _useExternalSourceCode = true; + initializeForExternalSourceCode(element, context); + } +} + +void CustomCodeNodeMdl::initializeForInlineSourceCode(const InterfaceElement& element, GenContext& context) +{ + const Implementation& impl = static_cast(element); + // Store the inline source because the `_functionSource` is used for the function call template string + // that matched the regular MaterialX to MDL function mapping. + _inlineSourceCode = impl.getAttribute("sourcecode"); + if (_inlineSourceCode.empty()) + { + throw ExceptionShaderGenError("No source code was specified for the implementation '" + impl.getName() + "'"); + } + if (_inlineSourceCode.find("//") != string::npos) + { + throw ExceptionShaderGenError("Source code contains unsupported comments '//', please use '/* comment */' instead in '" + impl.getName() + "'"); + } + + NodeDefPtr nodeDef = impl.getNodeDef(); + _inlineFunctionName = nodeDef->getName(); + _hash = std::hash{}(_inlineFunctionName); // make sure we emit the function definition only once + + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + // Construct the function call template string + initializeFunctionCallTemplateString(syntax, *nodeDef); + // Collect information about output names and defaults + initializeOutputDefaults(syntax, *nodeDef); +} + +void CustomCodeNodeMdl::initializeForExternalSourceCode(const InterfaceElement& element, GenContext& context) +{ + // Format the function source in a way that the ShaderCodeNodeMdl (the base class of the current one) can deal with it + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + const string uniformPrefix = syntax.getUniformQualifier() + " "; + + // Map `file` to a qualified MDL module name + const Implementation& impl = static_cast(element); + string moduleName = impl.getAttribute("file"); + if (moduleName.empty()) + { + throw ExceptionShaderGenError("No source file was specified for the implementation '" + impl.getName() + "'"); + } + if (_functionName.empty()) + { + throw ExceptionShaderGenError("No function name was specified for the implementation '" + impl.getName() + "'"); + } + + string mdlModuleName = replaceSubstrings(moduleName, { { "/", "::" } }); + if (!stringStartsWith(mdlModuleName, "::")) + { + mdlModuleName = "::" + mdlModuleName; + } + + if (!stringEndsWith(mdlModuleName, ".mdl")) + { + throw ExceptionShaderGenError("Referenced source file is not an MDL module: '" + moduleName + + "' used by implementation '" + impl.getName() + "'"); + } + else + { + mdlModuleName = mdlModuleName.substr(0, mdlModuleName.size() - 4); + } + const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); + _qualifiedModuleName = syntax.replaceSourceCodeMarkers(element.getName(), mdlModuleName, + [&versionSuffix, &syntax](const string& marker) + { + return marker == syntax.getMdlVersionSuffixMarker() ? versionSuffix : marker; + }); + + NodeDefPtr nodeDef = impl.getNodeDef(); + // Construct the function call template string + initializeFunctionCallTemplateString(syntax, *nodeDef); + // Collect information about output names and defaults + initializeOutputDefaults(syntax, *nodeDef); +} + +void CustomCodeNodeMdl::initializeFunctionCallTemplateString(const MdlSyntax& syntax, const NodeDef& nodeDef) +{ + // Construct the fully qualified function name for external functions + if (_useExternalSourceCode) + { + _functionSource = _qualifiedModuleName.substr(2) + "::" + _functionName + "("; + } + // or simple name for local functions + else + { + _functionSource = _inlineFunctionName + "("; + } + + // Function parameters + string delim = EMPTY_STRING; + for (const InputPtr& input : nodeDef.getInputs()) + { + string inputName = modifyPortName(input->getName(), syntax); + _functionSource += delim + inputName + ": {{" + input->getName() + "}}"; + if (delim == EMPTY_STRING) + delim = Syntax::COMMA + " "; + } + _functionSource += ")"; + _inlined = true; +} + +void CustomCodeNodeMdl::initializeOutputDefaults(const MdlSyntax&, const NodeDef& nodeDef) +{ + for (const OutputPtr& output : nodeDef.getOutputs()) + { + _outputDefaults.push_back(output->getValue()); + } +} + +void CustomCodeNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const +{ + // No source code printing for externally defined functions + if (_useExternalSourceCode) + { + return; + } + + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + shadergen.emitComment("generated code for implementation: '" + node.getImplementation().getName() + "'", stage); + + // Function return type + struct Field + { + string name; + string type_name; + string default_value; + }; + vector outputs; + size_t i = 0; + for (const ShaderOutput* output : node.getOutputs()) + { + string name = modifyPortName(output->getName(), syntax); + TypeDesc type = output->getType(); + const ValuePtr defaultValue = _outputDefaults[i]; + outputs.push_back({ + name, + syntax.getTypeName(type), + defaultValue ? syntax.getValue(type, *defaultValue.get()) : syntax.getDefaultValue(type) + }); + ++i; + } + + size_t numOutputs = node.getOutputs().size(); + string returnTypeName; + if (numOutputs == 1) + { + returnTypeName = outputs.back().type_name; + } + else + { + returnTypeName = _inlineFunctionName + "_return_type"; + shadergen.emitLine("struct " + returnTypeName, stage, false); + shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); + for (const auto& field : outputs) + { + // ignore the default values here, they have to be initialized in the body + shadergen.emitLine(field.type_name + " " + field.name, stage); + } + shadergen.emitScopeEnd(stage, Syntax::CURLY_BRACKETS); + shadergen.emitLineEnd(stage, false); + } + // Signature + shadergen.emitString(returnTypeName + " " + _inlineFunctionName, stage); + { + // Function parameters + shadergen.emitScopeBegin(stage, Syntax::PARENTHESES); + size_t paramCount = node.getInputs().size(); + const string uniformPrefix = syntax.getUniformQualifier() + " "; + for (const ShaderInput* input : node.getInputs()) + { + const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING; + const string& type = syntax.getTypeName(input->getType()); + const string name = modifyPortName(input->getName(), syntax); + const string& delim = --paramCount == 0 ? EMPTY_STRING : Syntax::COMMA; + shadergen.emitString(" " + qualifier + type + " " + name + delim + Syntax::NEWLINE, stage); + } + shadergen.emitScopeEnd(stage, false, true); + } + { + // Function body + shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); + + // Out variable initialization + shadergen.emitComment("initialize outputs:", stage); + for (const auto& field : outputs) + { + shadergen.emitLine(field.type_name + " " + field.name + " = " + field.default_value, stage); + } + + // User defined code + shadergen.emitComment("inlined shader source code:", stage); + shadergen.emitLine(_inlineSourceCode, stage, false); + + // Output packing + shadergen.emitComment("pack (in case of multiple outputs) and return outputs:", stage); + if (numOutputs == 1) + { + shadergen.emitLine("return " + outputs.back().name, stage, true); + } + else + { + // Return a constructor call of the return struct type + shadergen.emitString(" return " + returnTypeName + "(", stage); + string delim = EMPTY_STRING; + for (const auto& field : outputs) + { + shadergen.emitString(delim + field.name, stage); + if (delim == EMPTY_STRING) + delim = Syntax::COMMA + " "; + } + shadergen.emitString(")", stage); + shadergen.emitLineEnd(stage, true); + } + shadergen.emitScopeEnd(stage, false, true); + } + shadergen.emitLine("", stage, false); // empty line for spacing +} + +MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h new file mode 100644 index 0000000000..daf4c873a5 --- /dev/null +++ b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h @@ -0,0 +1,56 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef MATERIALX_CUSTOMNODEMDL_H +#define MATERIALX_CUSTOMNODEMDL_H + +#include + +MATERIALX_NAMESPACE_BEGIN + +class MdlSyntax; +class NodeDef; + +/// Node to handle user defined implementations in external MDL files or using the inline `sourcecode` attribute. +class MX_GENMDL_API CustomCodeNodeMdl : public SourceCodeNodeMdl +{ + public: + static ShaderNodeImplPtr create(); + void initialize(const InterfaceElement& element, GenContext& context) override; + void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; + + /// Get the MDL qualified name of the externally references user module. + /// It's used for import statements and functions calls in the generated target code. + const string& getQualifiedModuleName() const; + + /// To avoid collisions with reserved names in MDL, input and output names are prefixed. + /// In the `sourcecode` case all inputs and outputs are prefixed so authors don't need knowledge about reserved words in MDL. + /// In the `file` and `function` case, only reserved names are prefixed to support existing MDL implementations without changes. + string modifyPortName(const string& name, const MdlSyntax& syntax) const; + + protected: + /// Initialize function for nodes that use the inline `sourcecode` attribute. + void initializeForInlineSourceCode(const InterfaceElement& element, GenContext& context); + + /// Initialize function for nodes that use the `file` and `function` attribute. + void initializeForExternalSourceCode(const InterfaceElement& element, GenContext& context); + + /// Computes the function call string with replacement markers use by base class. + void initializeFunctionCallTemplateString(const MdlSyntax& syntax, const NodeDef& node); + + /// Keep track of the default values needed for the inline `sourcecode` case. + void initializeOutputDefaults(const MdlSyntax& syntax, const NodeDef& node); + + std::vector _outputDefaults; ///< store default values of the node definition + + bool _useExternalSourceCode; // Indicates that `file` and `function` are used by this node implementation + string _inlineFunctionName; // Name of the functionDefinition to emit + string _inlineSourceCode; // The actual inline source code + string _qualifiedModuleName; // MDL qualified name derived from the `file` attribute +}; + +MATERIALX_NAMESPACE_END + +#endif diff --git a/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h b/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h index fe88b2ce09..8363c6f4fa 100644 --- a/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h @@ -18,7 +18,7 @@ class MX_GENMDL_API ImageNodeMdl : public SourceCodeNodeMdl using BASE = SourceCodeNodeMdl; public: - static const string FLIP_V; ///< the empty string "" + static const string FLIP_V; ///< name of the additional parameter "flip_v" static ShaderNodeImplPtr create(); diff --git a/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp index d13aa2871f..6891fc1778 100644 --- a/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -32,6 +33,7 @@ void MaterialNodeMdl::emitFunctionCall(const ShaderNode& _node, GenContext& cont const ShaderGenerator& shadergen = context.getShaderGenerator(); const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit the function call for upstream surface shader. const ShaderNode* surfaceshaderNode = surfaceshaderInput->getConnection()->getNode(); @@ -50,8 +52,7 @@ void MaterialNodeMdl::emitFunctionCall(const ShaderNode& _node, GenContext& cont for (ShaderInput* input : node.getInputs()) { shadergen.emitString(delim, stage); - shadergen.emitString("mxp_", stage); - shadergen.emitString(input->getName(), stage); + shadergen.emitString(mdlSyntax.modifyPortName(input->getName()), stage); shadergen.emitString(": ", stage); shadergen.emitInput(input, context, stage); delim = ", "; diff --git a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp index f237a0db54..5547ae309d 100644 --- a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp @@ -13,51 +13,28 @@ #include #include +#include + #include MATERIALX_NAMESPACE_BEGIN -namespace // anonymous -{ -const string MARKER_MDL_VERSION_SUFFIX = "MDL_VERSION_SUFFIX"; - -StringVec replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) +ShaderNodeImplPtr SourceCodeNodeMdl::create() { - // An inline function call - // Replace tokens of the format "{{}}" - static const string prefix("{{"); - static const string postfix("}}"); - - size_t pos = 0; - size_t i = soureCode.find_first_of(prefix); - StringVec code; - while (i != string::npos) - { - code.push_back(soureCode.substr(pos, i - pos)); - size_t j = soureCode.find_first_of(postfix, i + 2); - if (j == string::npos) - { - throw ExceptionShaderGenError("Malformed inline expression in implementation for node " + nodeName); - } - const string marker = soureCode.substr(i + 2, j - i - 2); - code.push_back(lambda(marker)); - pos = j + 2; - i = soureCode.find_first_of(prefix, pos); - } - code.push_back(soureCode.substr(pos)); - return code; + return std::make_shared(); } -} // anonymous namespace - -ShaderNodeImplPtr SourceCodeNodeMdl::create() +void SourceCodeNodeMdl::resolveSourceCode(const InterfaceElement& /*element*/, GenContext& /*context*/) { - return std::make_shared(); + // Initialize without fetching the source code from file. + // The resolution of MDL modules is done by the MDL compiler when loading the generated source code. + // All references MDL modules must be accessible via MDL search paths set up by the consuming application. } void SourceCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& context) { SourceCodeNode::initialize(element, context); + const MdlSyntax& syntax = static_cast(context.getShaderGenerator().getSyntax()); const Implementation& impl = static_cast(element); NodeDefPtr nodeDef = impl.getNodeDef(); @@ -77,11 +54,10 @@ void SourceCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& const ShaderGenerator& shadergen = context.getShaderGenerator(); const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); - StringVec code = replaceSourceCodeMarkers(getName(), functionName, [&versionSuffix](const string& marker) + functionName = syntax.replaceSourceCodeMarkers(getName(), functionName, [&versionSuffix, syntax](const string& marker) { - return marker == MARKER_MDL_VERSION_SUFFIX ? versionSuffix : EMPTY_STRING; + return marker == syntax.getMdlVersionSuffixMarker() ? versionSuffix : EMPTY_STRING; }); - functionName = std::accumulate(code.begin(), code.end(), EMPTY_STRING); _returnStruct = functionName + "__result"; } else @@ -103,12 +79,13 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); if (_inlined) { + const MdlSyntax& syntax = static_cast(shadergenMdl.getSyntax()); const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); - StringVec code = replaceSourceCodeMarkers(node.getName(), _functionSource, - [&shadergenMdl, &context, &node, &versionSuffix](const string& marker) + string code = syntax.replaceSourceCodeMarkers(node.getName(), _functionSource, + [&shadergenMdl, &context, &node, &versionSuffix, syntax](const string& marker) { // Special handling for the version suffix of MDL source code modules. - if (marker == MARKER_MDL_VERSION_SUFFIX) + if (marker == syntax.getMdlVersionSuffixMarker()) { return versionSuffix; } @@ -131,7 +108,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con // Emit the struct multioutput. const string resultVariableName = node.getName() + "_result"; shadergen.emitLineBegin(stage); - shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage); + shadergen.emitString("auto " + resultVariableName + " = ", stage); } else { @@ -141,10 +118,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con shadergen.emitString(" = ", stage); } - for (const string& c : code) - { - shadergen.emitString(c, stage); - } + shadergen.emitString(code, stage); shadergen.emitLineEnd(stage); } else @@ -156,7 +130,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con // Emit the struct multioutput. const string resultVariableName = node.getName() + "_result"; shadergen.emitLineBegin(stage); - shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage); + shadergen.emitString("auto " + resultVariableName + " = ", stage); } else { diff --git a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h index 23c47e6249..7f595704cc 100644 --- a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h @@ -25,6 +25,7 @@ class MX_GENMDL_API SourceCodeNodeMdl : public SourceCodeNode void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; protected: + void resolveSourceCode(const InterfaceElement& element, GenContext& context) override; string _returnStruct; }; diff --git a/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp index dd0444d129..8904d835f4 100644 --- a/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp @@ -6,6 +6,7 @@ #include #include +#include #include @@ -55,6 +56,7 @@ void SurfaceNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& contex DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const MdlShaderGenerator& shadergen = static_cast(context.getShaderGenerator()); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit calls for the closure dependencies upstream from this node. shadergen.emitDependentFunctionCalls(node, context, stage, ShaderNode::Classification::CLOSURE); @@ -84,8 +86,7 @@ void SurfaceNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& contex for (ShaderInput* input : node.getInputs()) { shadergen.emitString(delim, stage); - shadergen.emitString("mxp_", stage); - shadergen.emitString(input->getName(), stage); + shadergen.emitString(mdlSyntax.modifyPortName(input->getName()), stage); shadergen.emitString(": ", stage); shadergen.emitInput(input, context, stage); delim = ", "; diff --git a/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp b/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp index 936edf9b46..32fc345354 100644 --- a/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp +++ b/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp @@ -25,6 +25,20 @@ ShaderNodeImplPtr SourceCodeNode::create() return std::make_shared(); } +void SourceCodeNode::resolveSourceCode(const InterfaceElement& element, GenContext& context) +{ + const Implementation& impl = static_cast(element); + + FilePath localPath = FilePath(impl.getActiveSourceUri()).getParentPath(); + _sourceFilename = context.resolveSourceFile(impl.getAttribute("file"), localPath); + _functionSource = readFile(_sourceFilename); + if (_functionSource.empty()) + { + throw ExceptionShaderGenError("Failed to get source code from file '" + _sourceFilename.asString() + + "' used by implementation '" + impl.getName() + "'"); + } +} + void SourceCodeNode::initialize(const InterfaceElement& element, GenContext& context) { ShaderNodeImpl::initialize(element, context); @@ -40,19 +54,13 @@ void SourceCodeNode::initialize(const InterfaceElement& element, GenContext& con _functionSource = impl.getAttribute("sourcecode"); if (_functionSource.empty()) { - FilePath localPath = FilePath(impl.getActiveSourceUri()).getParentPath(); - _sourceFilename = context.resolveSourceFile(impl.getAttribute("file"), localPath); - _functionSource = readFile(_sourceFilename); - if (_functionSource.empty()) - { - throw ExceptionShaderGenError("Failed to get source code from file '" + _sourceFilename.asString() + - "' used by implementation '" + impl.getName() + "'"); - } + resolveSourceCode(element, context); } // Find the function name to use // If no function is given the source will be inlined. _functionName = impl.getAttribute("function"); + _inlined = _functionName.empty(); if (!_inlined) { diff --git a/source/MaterialXGenShader/Nodes/SourceCodeNode.h b/source/MaterialXGenShader/Nodes/SourceCodeNode.h index a208185a8d..6169dc0f61 100644 --- a/source/MaterialXGenShader/Nodes/SourceCodeNode.h +++ b/source/MaterialXGenShader/Nodes/SourceCodeNode.h @@ -26,6 +26,9 @@ class MX_GENSHADER_API SourceCodeNode : public ShaderNodeImpl void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; protected: + /// Resolve the source file and read the source code during the initialization of the node. + virtual void resolveSourceCode(const InterfaceElement& element, GenContext& context); + bool _inlined; string _functionName; string _functionSource; diff --git a/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp b/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp index 4bec0016de..610b237163 100644 --- a/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp +++ b/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp @@ -270,7 +270,7 @@ void MdlShaderGeneratorTester::compileSource(const std::vector& so CHECK(returnValue == 0); } - if (!renderExec.empty()) // render if renderer is availabe + if (!renderExec.empty()) // render if renderer is available { std::string renderCommand = renderExec;