From 35659186fed199e072466c7d1b7f06a481e245f8 Mon Sep 17 00:00:00 2001 From: chris_bednarski Date: Tue, 18 Jul 2023 12:35:27 +1000 Subject: [PATCH] add firewall extension decompiler, make msi modifications work, add service name, edge traversal and interface type attributes --- src/ext/Firewall/ca/firewall.cpp | 780 +++++++++++++----- .../FirewallExtensionFixture.cs | 43 +- .../UsingFirewall/PackageComponents.wxs | 13 +- src/ext/Firewall/wixext/FirewallCompiler.cs | 253 +++++- src/ext/Firewall/wixext/FirewallConstants.cs | 14 +- src/ext/Firewall/wixext/FirewallDecompiler.cs | 233 ++++-- .../wixext/FirewallExtensionFactory.cs | 3 +- .../wixext/FirewallTableDefinitions.cs | 12 +- .../Symbols/WixFirewallExceptionSymbol.cs | 34 +- src/ext/Iis/ca/scacertexec.cpp | 6 +- src/ext/Util/wixext/UtilCompiler.cs | 13 +- 11 files changed, 1097 insertions(+), 307 deletions(-) diff --git a/src/ext/Firewall/ca/firewall.cpp b/src/ext/Firewall/ca/firewall.cpp index b45cbcdd7..0776788db 100644 --- a/src/ext/Firewall/ca/firewall.cpp +++ b/src/ext/Firewall/ca/firewall.cpp @@ -3,12 +3,34 @@ #include "precomp.h" LPCWSTR vcsFirewallExceptionQuery = - L"SELECT `Name`, `RemoteAddresses`, `Port`, `Protocol`, `Program`, `Attributes`, `Profile`, `Component_`, `Description`, `Direction` FROM `Wix4FirewallException`"; -enum eFirewallExceptionQuery { feqName = 1, feqRemoteAddresses, feqPort, feqProtocol, feqProgram, feqAttributes, feqProfile, feqComponent, feqDescription, feqDirection }; + L"SELECT `Name`, `RemoteAddresses`, `Port`, `Protocol`, `Program`, `Attributes`, `Profile`, `Component_`, `Description`, `Direction`, `Service`, `InterfaceTypes`, `EdgeTraversal`, `RemotePort` FROM `Wix4FirewallException`"; +enum eFirewallExceptionQuery { feqName = 1, feqRemoteAddresses, feqPort, feqProtocol, feqProgram, feqAttributes, feqProfile, feqComponent, feqDescription, feqDirection, feqService, feqInterfaceTypes, feqEdgeTraversal, feqRemotePort }; enum eFirewallExceptionTarget { fetPort = 1, fetApplication, fetUnknown }; -enum eFirewallExceptionAttributes { feaIgnoreFailures = 1 }; +enum eFirewallExceptionAttributes { feaIgnoreFailures = 1, feaOverwriteOnChange = 2, feaEnabled = 4, feaEnabledCustomized = 8 }; -/****************************************************************** +#ifndef DISABLE_PROFILES +#define DISABLE_PROFILES FALSE +#endif // !DISABLE_PROFILES + +struct FIREWALL_EXCEPTION_ATTRIBUTES +{ + LPWSTR pwzName; + + LPWSTR pwzRemoteAddresses; + LPWSTR pwzPort; + int iProtocol; + LPWSTR pwzProgram; + int iAttributes; + int iProfile; + LPWSTR pwzDescription; + int iDirection; + LPWSTR pwzService; + LPWSTR pwzInterfaceTypes; + int iEdgeTraversal; + LPWSTR pwzRemotePort; +}; + +/******************************************************************* SchedFirewallExceptions - immediate custom action worker to register and remove firewall exceptions. @@ -26,17 +48,9 @@ static UINT SchedFirewallExceptions( PMSIHANDLE hRec = NULL; LPWSTR pwzCustomActionData = NULL; - LPWSTR pwzName = NULL; - LPWSTR pwzRemoteAddresses = NULL; - LPWSTR pwzPort = NULL; - int iProtocol = 0; - int iAttributes = 0; - int iProfile = 0; - LPWSTR pwzProgram = NULL; LPWSTR pwzComponent = NULL; - LPWSTR pwzFormattedFile = NULL; - LPWSTR pwzDescription = NULL; - int iDirection = MSI_NULL_INTEGER; + + FIREWALL_EXCEPTION_ATTRIBUTES attrs = { 0 }; // initialize hr = WcaInitialize(hInstall, "SchedFirewallExceptions"); @@ -55,36 +69,61 @@ static UINT SchedFirewallExceptions( while (S_OK == (hr = WcaFetchRecord(hView, &hRec))) { - hr = WcaGetRecordFormattedString(hRec, feqName, &pwzName); + ReleaseNullStr(pwzComponent); + ReleaseNullStr(attrs.pwzName); + ReleaseNullStr(attrs.pwzRemoteAddresses); + ReleaseNullStr(attrs.pwzPort); + ReleaseNullStr(attrs.pwzProgram); + ReleaseNullStr(attrs.pwzDescription); + ReleaseNullStr(attrs.pwzService); + ReleaseNullStr(attrs.pwzInterfaceTypes); + ReleaseNullStr(attrs.pwzRemotePort); + + attrs.iDirection = MSI_NULL_INTEGER; + attrs.iEdgeTraversal = MSI_NULL_INTEGER; + + hr = WcaGetRecordFormattedString(hRec, feqName, &attrs.pwzName); ExitOnFailure(hr, "Failed to get firewall exception name."); - hr = WcaGetRecordFormattedString(hRec, feqRemoteAddresses, &pwzRemoteAddresses); + hr = WcaGetRecordFormattedString(hRec, feqRemoteAddresses, &attrs.pwzRemoteAddresses); ExitOnFailure(hr, "Failed to get firewall exception remote addresses."); - hr = WcaGetRecordFormattedString(hRec, feqPort, &pwzPort); + hr = WcaGetRecordFormattedString(hRec, feqPort, &attrs.pwzPort); ExitOnFailure(hr, "Failed to get firewall exception port."); - hr = WcaGetRecordInteger(hRec, feqProtocol, &iProtocol); + hr = WcaGetRecordInteger(hRec, feqProtocol, &attrs.iProtocol); ExitOnFailure(hr, "Failed to get firewall exception protocol."); - hr = WcaGetRecordFormattedString(hRec, feqProgram, &pwzProgram); + hr = WcaGetRecordFormattedString(hRec, feqProgram, &attrs.pwzProgram); ExitOnFailure(hr, "Failed to get firewall exception program."); - hr = WcaGetRecordInteger(hRec, feqAttributes, &iAttributes); + hr = WcaGetRecordInteger(hRec, feqAttributes, &attrs.iAttributes); ExitOnFailure(hr, "Failed to get firewall exception attributes."); - hr = WcaGetRecordInteger(hRec, feqProfile, &iProfile); + hr = WcaGetRecordInteger(hRec, feqProfile, &attrs.iProfile); ExitOnFailure(hr, "Failed to get firewall exception profile."); hr = WcaGetRecordString(hRec, feqComponent, &pwzComponent); ExitOnFailure(hr, "Failed to get firewall exception component."); - hr = WcaGetRecordString(hRec, feqDescription, &pwzDescription); + hr = WcaGetRecordFormattedString(hRec, feqDescription, &attrs.pwzDescription); ExitOnFailure(hr, "Failed to get firewall exception description."); - hr = WcaGetRecordInteger(hRec, feqDirection, &iDirection); + hr = WcaGetRecordInteger(hRec, feqDirection, &attrs.iDirection); ExitOnFailure(hr, "Failed to get firewall exception direction."); + hr = WcaGetRecordFormattedString(hRec, feqService, &attrs.pwzService); + ExitOnFailure(hr, "Failed to get firewall exception service."); + + hr = WcaGetRecordString(hRec, feqInterfaceTypes, &attrs.pwzInterfaceTypes); + ExitOnFailure(hr, "Failed to get firewall exception interface types."); + + hr = WcaGetRecordInteger(hRec, feqEdgeTraversal, &attrs.iEdgeTraversal); + ExitOnFailure(hr, "Failed to get firewall exception edge traversal."); + + hr = WcaGetRecordFormattedString(hRec, feqRemotePort, &attrs.pwzRemotePort); + ExitOnFailure(hr, "Failed to get firewall exception remote port."); + // figure out what we're doing for this exception, treating reinstall the same as install WCA_TODO todoComponent = WcaGetComponentToDo(pwzComponent); if ((WCA_TODO_REINSTALL == todoComponent ? WCA_TODO_INSTALL : todoComponent) != todoSched) @@ -93,30 +132,29 @@ static UINT SchedFirewallExceptions( continue; } - // action :: name :: profile :: remoteaddresses :: attributes :: target :: {port::protocol | path} ++cFirewallExceptions; hr = WcaWriteIntegerToCaData(todoComponent, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception action to custom action data"); - hr = WcaWriteStringToCaData(pwzName, &pwzCustomActionData); + hr = WcaWriteStringToCaData(attrs.pwzName, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception name to custom action data"); - hr = WcaWriteIntegerToCaData(iProfile, &pwzCustomActionData); + hr = WcaWriteIntegerToCaData(attrs.iProfile, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception profile to custom action data"); - hr = WcaWriteStringToCaData(pwzRemoteAddresses, &pwzCustomActionData); + hr = WcaWriteStringToCaData(attrs.pwzRemoteAddresses, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception remote addresses to custom action data"); - hr = WcaWriteIntegerToCaData(iAttributes, &pwzCustomActionData); + hr = WcaWriteIntegerToCaData(attrs.iAttributes, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception attributes to custom action data"); - if (*pwzProgram) + if (*attrs.pwzProgram) { // If program is defined, we have an application exception. hr = WcaWriteIntegerToCaData(fetApplication, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception target (application) to custom action data"); - hr = WcaWriteStringToCaData(pwzProgram, &pwzCustomActionData); + hr = WcaWriteStringToCaData(attrs.pwzProgram, &pwzCustomActionData); ExitOnFailure(hr, "failed to write application path to custom action data"); } else @@ -126,24 +164,36 @@ static UINT SchedFirewallExceptions( ExitOnFailure(hr, "failed to write exception target (port) to custom action data"); } - hr = WcaWriteStringToCaData(pwzPort, &pwzCustomActionData); + hr = WcaWriteStringToCaData(attrs.pwzPort, &pwzCustomActionData); ExitOnFailure(hr, "failed to write application path to custom action data"); - hr = WcaWriteIntegerToCaData(iProtocol, &pwzCustomActionData); + hr = WcaWriteIntegerToCaData(attrs.iProtocol, &pwzCustomActionData); ExitOnFailure(hr, "failed to write exception protocol to custom action data"); - hr = WcaWriteStringToCaData(pwzDescription, &pwzCustomActionData); + hr = WcaWriteStringToCaData(attrs.pwzDescription, &pwzCustomActionData); ExitOnFailure(hr, "failed to write firewall rule description to custom action data"); - hr = WcaWriteIntegerToCaData(iDirection, &pwzCustomActionData); + hr = WcaWriteIntegerToCaData(attrs.iDirection, &pwzCustomActionData); ExitOnFailure(hr, "failed to write firewall rule direction to custom action data"); + + hr = WcaWriteStringToCaData(attrs.pwzService, &pwzCustomActionData); + ExitOnFailure(hr, "failed to write firewall rule service to custom action data"); + + hr = WcaWriteStringToCaData(attrs.pwzInterfaceTypes, &pwzCustomActionData); + ExitOnFailure(hr, "failed to write firewall rule interface types to custom action data"); + + hr = WcaWriteIntegerToCaData(attrs.iEdgeTraversal, &pwzCustomActionData); + ExitOnFailure(hr, "failed to write firewall rule edge traversal to custom action data"); + + hr = WcaWriteStringToCaData(attrs.pwzRemotePort, &pwzCustomActionData); + ExitOnFailure(hr, "failed to write exception source (port) to custom action data"); } // reaching the end of the list is actually a good thing, not an error if (E_NOMOREITEMS == hr) { hr = S_OK; - } + } ExitOnFailure(hr, "failure occured while processing Wix4FirewallException table"); // schedule ExecFirewallExceptions if there's anything to do @@ -172,14 +222,16 @@ static UINT SchedFirewallExceptions( } LExit: - ReleaseStr(pwzCustomActionData); - ReleaseStr(pwzName); - ReleaseStr(pwzRemoteAddresses); - ReleaseStr(pwzPort); - ReleaseStr(pwzProgram); - ReleaseStr(pwzComponent); - ReleaseStr(pwzDescription); - ReleaseStr(pwzFormattedFile); + ReleaseNullStr(attrs.pwzName); + ReleaseNullStr(attrs.pwzRemoteAddresses); + ReleaseNullStr(attrs.pwzPort); + ReleaseNullStr(attrs.pwzProgram); + ReleaseNullStr(attrs.pwzDescription); + ReleaseNullStr(attrs.pwzService); + ReleaseNullStr(attrs.pwzInterfaceTypes); + ReleaseNullStr(attrs.pwzRemotePort); + ReleaseNullStr(pwzComponent); + ReleaseNullStr(pwzCustomActionData); return WcaFinalize(er = FAILED(hr) ? ERROR_INSTALL_FAILURE : er); } @@ -265,6 +317,176 @@ static HRESULT GetFirewallRules( return hr; } +/************************************************************************* + UpdateFwRuleObject - update the common set of properties which are shared + between port and application firewall rules + +**************************************************************************/ +static HRESULT UpdateFwRuleObject( + __in INetFwRule* pNetFwRule, + __in BOOL fOverwriteOnChange, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs + ) +{ + HRESULT hr = S_OK; + BSTR bstrPort = NULL; + BSTR bstrDescription = NULL; + BSTR bstrRemoteAddresses = NULL; + BSTR bstrServiceName = NULL; + BSTR bstrInterfaceTypes = NULL; + BSTR bstrRemotePort = NULL; + INetFwRule2* pNetFwRule2 = NULL; + BOOL fEnabled = feaEnabled == (pAttrs->iAttributes & feaEnabled); + BOOL fEnabledCustomized = feaEnabledCustomized == (pAttrs->iAttributes & feaEnabledCustomized); + + // convert to BSTRs to make COM happy + bstrPort = ::SysAllocString(pAttrs->pwzPort); + ExitOnNull(bstrPort, hr, E_OUTOFMEMORY, "failed SysAllocString for port update"); + bstrDescription = ::SysAllocString(pAttrs->pwzDescription); + ExitOnNull(bstrDescription, hr, E_OUTOFMEMORY, "failed SysAllocString for description update"); + bstrRemoteAddresses = ::SysAllocString(pAttrs->pwzRemoteAddresses); + ExitOnNull(bstrRemoteAddresses, hr, E_OUTOFMEMORY, "failed SysAllocString for remote addresses update"); + bstrServiceName = ::SysAllocString(pAttrs->pwzService); + ExitOnNull(bstrServiceName, hr, E_OUTOFMEMORY, "failed SysAllocString for service name update"); + bstrInterfaceTypes = ::SysAllocString(pAttrs->pwzInterfaceTypes); + ExitOnNull(bstrInterfaceTypes, hr, E_OUTOFMEMORY, "failed SysAllocString for interface types update"); + bstrRemotePort = ::SysAllocString(pAttrs->pwzRemotePort); + ExitOnNull(bstrRemotePort, hr, E_OUTOFMEMORY, "failed SysAllocString for remote port update"); + + // change it or enable it (just in case it was disabled) for backwards compatibilty + VARIANT_BOOL fEnable = fOverwriteOnChange ? (fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE) : VARIANT_TRUE; + hr = pNetFwRule->put_Enabled(fEnable); + ExitOnFailure(hr, "failed to re%s the exception rule", VARIANT_TRUE == fEnable ? "enable" : "disable"); + + if (fOverwriteOnChange) + { + // If you are editing a TCP port rule and converting it into an ICMP rule, + // first delete the port, change protocol from TCP to ICMP, and then add the rule. + hr = pNetFwRule->put_LocalPorts(NULL); + ExitOnFailure(hr, "failed to update exception rule local ports to NULL"); + + hr = pNetFwRule->put_RemotePorts(NULL); + ExitOnFailure(hr, "failed to update exception rule remote ports to NULL"); + + // The Protocol property must be set before the LocalPorts or RemotePorts properties or an error will be returned. + if (MSI_NULL_INTEGER != pAttrs->iProtocol) + { + hr = pNetFwRule->put_Protocol(static_cast (pAttrs->iProtocol)); + ExitOnFailure(hr, "failed to update exception rule protocol"); + } + else if(bstrPort && *bstrPort) + { + // default protocol is "TCP" in the WiX firewall compiler if the port is specified + hr = pNetFwRule->put_Protocol(NET_FW_IP_PROTOCOL_TCP); + ExitOnFailure(hr, "failed to reset exception rule protocol to TCP"); + } + else + { + hr = pNetFwRule->put_Protocol(NET_FW_IP_PROTOCOL_ANY); + ExitOnFailure(hr, "failed to reset exception rule protocol to ANY"); + } + + if (bstrPort && *bstrPort) + { + hr = pNetFwRule->put_LocalPorts(bstrPort); + ExitOnFailure(hr, "failed to update exception rule port"); + } + + if (bstrDescription && *bstrDescription) + { + hr = pNetFwRule->put_Description(bstrDescription); + ExitOnFailure(hr, "failed to update exception rule description '%ls'", bstrDescription); + } + else + { + hr = pNetFwRule->put_Description(NULL); + ExitOnFailure(hr, "failed to remove exception rule description"); + } + + if (MSI_NULL_INTEGER != pAttrs->iDirection) + { + hr = pNetFwRule->put_Direction(static_cast (pAttrs->iDirection)); + ExitOnFailure(hr, "failed to update exception rule direction"); + } + else + { + // If this property is not specified, the default value is in. + hr = pNetFwRule->put_Direction(NET_FW_RULE_DIR_IN); + ExitOnFailure(hr, "failed to reset exception rule direction"); + } + + if (bstrRemoteAddresses && *bstrRemoteAddresses) + { + hr = pNetFwRule->put_RemoteAddresses(bstrRemoteAddresses); + ExitOnFailure(hr, "failed to update exception rule remote addresses '%ls'", bstrRemoteAddresses); + } + else + { + hr = pNetFwRule->put_RemoteAddresses(NULL); + ExitOnFailure(hr, "failed to remove exception rule remote addresses"); + } + + if (bstrServiceName && *bstrServiceName) + { + hr = pNetFwRule->put_ServiceName(bstrServiceName); + ExitOnFailure(hr, "failed to update exception rule service name"); + } + else + { + hr = pNetFwRule->put_ServiceName(NULL); + ExitOnFailure(hr, "failed to remove exception rule service name"); + } + + if (bstrInterfaceTypes && *bstrInterfaceTypes) + { + hr = pNetFwRule->put_InterfaceTypes(bstrInterfaceTypes); + ExitOnFailure(hr, "failed to update exception rule interface types"); + } + else + { + hr = pNetFwRule->put_InterfaceTypes(NULL); + ExitOnFailure(hr, "failed to remove exception rule interface types"); + } + + if (MSI_NULL_INTEGER != pAttrs->iEdgeTraversal) + { + if (SUCCEEDED(pNetFwRule->QueryInterface(__uuidof(INetFwRule2), (void**)&pNetFwRule2))) + { + hr = pNetFwRule2->put_EdgeTraversalOptions(pAttrs->iEdgeTraversal); + ExitOnFailure(hr, "failed to set exception rule edge traversal options property"); + } + else + { + hr = pNetFwRule->put_EdgeTraversal(NET_FW_EDGE_TRAVERSAL_TYPE_DENY != pAttrs->iEdgeTraversal ? VARIANT_TRUE : VARIANT_FALSE); + ExitOnFailure(hr, "failed to set exception rule edge traversal property"); + } + } + else + { + // New rules have the EdgeTraversal property disabled by default. + hr = pNetFwRule->put_EdgeTraversal(VARIANT_FALSE); + ExitOnFailure(hr, "failed to reset exception rule edge traversal property"); + } + + if (bstrRemotePort && *bstrRemotePort) + { + hr = pNetFwRule->put_RemotePorts(bstrRemotePort); + ExitOnFailure(hr, "failed to update exception rule remote port"); + } + } + +LExit: + ReleaseBSTR(bstrPort); + ReleaseBSTR(bstrDescription); + ReleaseBSTR(bstrRemoteAddresses); + ReleaseBSTR(bstrServiceName); + ReleaseBSTR(bstrInterfaceTypes); + ReleaseBSTR(bstrRemotePort); + ReleaseObject(pNetFwRule2); + + return hr; +} + /****************************************************************** CreateFwRuleObject - CoCreate a firewall rule, and set the common set of properties which are shared between port and application firewall rules @@ -272,29 +494,36 @@ static HRESULT GetFirewallRules( ********************************************************************/ static HRESULT CreateFwRuleObject( __in BSTR bstrName, - __in int iProfile, - __in_opt LPCWSTR wzRemoteAddresses, - __in LPCWSTR wzPort, - __in int iProtocol, - __in LPCWSTR wzDescription, - __in int iDirection, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __out INetFwRule** ppNetFwRule - ) +) { HRESULT hr = S_OK; BSTR bstrRemoteAddresses = NULL; BSTR bstrPort = NULL; BSTR bstrDescription = NULL; + BSTR bstrServiceName = NULL; + BSTR bstrInterfaceTypes = NULL; + BSTR bstrRemotePort = NULL; INetFwRule* pNetFwRule = NULL; + INetFwRule2* pNetFwRule2 = NULL; *ppNetFwRule = NULL; + BOOL fEnabled = feaEnabled == (pAttrs->iAttributes & feaEnabled); + BOOL fEnabledCustomized = feaEnabledCustomized == (pAttrs->iAttributes & feaEnabledCustomized); // convert to BSTRs to make COM happy - bstrRemoteAddresses = ::SysAllocString(wzRemoteAddresses); + bstrRemoteAddresses = ::SysAllocString(pAttrs->pwzRemoteAddresses); ExitOnNull(bstrRemoteAddresses, hr, E_OUTOFMEMORY, "failed SysAllocString for remote addresses"); - bstrPort = ::SysAllocString(wzPort); + bstrPort = ::SysAllocString(pAttrs->pwzPort); ExitOnNull(bstrPort, hr, E_OUTOFMEMORY, "failed SysAllocString for port"); - bstrDescription = ::SysAllocString(wzDescription); + bstrDescription = ::SysAllocString(pAttrs->pwzDescription); ExitOnNull(bstrDescription, hr, E_OUTOFMEMORY, "failed SysAllocString for description"); + bstrServiceName = ::SysAllocString(pAttrs->pwzService); + ExitOnNull(bstrServiceName, hr, E_OUTOFMEMORY, "failed SysAllocString for service name"); + bstrInterfaceTypes = ::SysAllocString(pAttrs->pwzInterfaceTypes); + ExitOnNull(bstrInterfaceTypes, hr, E_OUTOFMEMORY, "failed SysAllocString for interface types"); + bstrRemotePort = ::SysAllocString(pAttrs->pwzRemotePort); + ExitOnNull(bstrPort, hr, E_OUTOFMEMORY, "failed SysAllocString for remote port"); hr = ::CoCreateInstance(__uuidof(NetFwRule), NULL, CLSCTX_ALL, __uuidof(INetFwRule), (void**)&pNetFwRule); ExitOnFailure(hr, "failed to create NetFwRule object"); @@ -302,12 +531,12 @@ static HRESULT CreateFwRuleObject( hr = pNetFwRule->put_Name(bstrName); ExitOnFailure(hr, "failed to set exception name"); - hr = pNetFwRule->put_Profiles(static_cast(iProfile)); + hr = pNetFwRule->put_Profiles(static_cast (pAttrs->iProfile)); ExitOnFailure(hr, "failed to set exception profile"); - if (MSI_NULL_INTEGER != iProtocol) + if (MSI_NULL_INTEGER != pAttrs->iProtocol) { - hr = pNetFwRule->put_Protocol(static_cast(iProtocol)); + hr = pNetFwRule->put_Protocol(static_cast (pAttrs->iProtocol)); ExitOnFailure(hr, "failed to set exception protocol"); } @@ -329,12 +558,47 @@ static HRESULT CreateFwRuleObject( ExitOnFailure(hr, "failed to set exception description '%ls'", bstrDescription); } - if (MSI_NULL_INTEGER != iDirection) + if (MSI_NULL_INTEGER != pAttrs->iDirection) { - hr = pNetFwRule->put_Direction(static_cast (iDirection)); + hr = pNetFwRule->put_Direction(static_cast (pAttrs->iDirection)); ExitOnFailure(hr, "failed to set exception direction"); } + if (bstrServiceName && *bstrServiceName) + { + hr = pNetFwRule->put_ServiceName(bstrServiceName); + ExitOnFailure(hr, "failed to set exception service name"); + } + + if (bstrInterfaceTypes && *bstrInterfaceTypes) + { + hr = pNetFwRule->put_InterfaceTypes(bstrInterfaceTypes); + ExitOnFailure(hr, "failed to set exception interface types"); + } + + if (MSI_NULL_INTEGER != pAttrs->iEdgeTraversal) + { + if (SUCCEEDED(pNetFwRule->QueryInterface(__uuidof(INetFwRule2), (void**)&pNetFwRule2))) + { + hr = pNetFwRule2->put_EdgeTraversalOptions(pAttrs->iEdgeTraversal); + ExitOnFailure(hr, "failed to set exception edge traversal options property"); + } + else + { + hr = pNetFwRule->put_EdgeTraversal(NET_FW_EDGE_TRAVERSAL_TYPE_DENY != pAttrs->iEdgeTraversal ? VARIANT_TRUE : VARIANT_FALSE); + ExitOnFailure(hr, "failed to set exception edge traversal property"); + } + } + + if (bstrRemotePort && *bstrRemotePort) + { + hr = pNetFwRule->put_RemotePorts(bstrRemotePort); + ExitOnFailure(hr, "failed to set exception remote port"); + } + + hr = pNetFwRule->put_Enabled(fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE); + ExitOnFailure(hr, "failed to enable exception"); + *ppNetFwRule = pNetFwRule; pNetFwRule = NULL; @@ -342,7 +606,11 @@ static HRESULT CreateFwRuleObject( ReleaseBSTR(bstrRemoteAddresses); ReleaseBSTR(bstrPort); ReleaseBSTR(bstrDescription); + ReleaseBSTR(bstrServiceName); + ReleaseBSTR(bstrInterfaceTypes); + ReleaseBSTR(bstrRemotePort); ReleaseObject(pNetFwRule); + ReleaseObject(pNetFwRule2); return hr; } @@ -354,18 +622,23 @@ static HRESULT CreateFwRuleObject( ********************************************************************/ static BOOL FSupportProfiles() { +#if defined(DISABLE_PROFILES) && (DISABLE_PROFILES != FALSE) + return FALSE; +#else + BOOL fSupportProfiles = FALSE; INetFwRules* pNetFwRules = NULL; // We only support profiles if we can co-create an instance of NetFwPolicy2. // This will not work on pre-vista machines. - if (SUCCEEDED(GetFirewallRules(TRUE, &pNetFwRules)) && pNetFwRules != NULL) + if (SUCCEEDED(GetFirewallRules(TRUE, &pNetFwRules)) && NULL != pNetFwRules) { fSupportProfiles = TRUE; ReleaseObject(pNetFwRules); } return fSupportProfiles; +#endif } /****************************************************************** @@ -436,15 +709,9 @@ static HRESULT GetCurrentFirewallProfile( ********************************************************************/ static HRESULT AddApplicationException( - __in LPCWSTR wzFile, - __in LPCWSTR wzName, - __in int iProfile, - __in_opt LPCWSTR wzRemoteAddresses, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __in BOOL fIgnoreFailures, - __in LPCWSTR wzPort, - __in int iProtocol, - __in LPCWSTR wzDescription, - __in int iDirection + __in BOOL fOverwriteOnChange ) { HRESULT hr = S_OK; @@ -454,9 +721,9 @@ static HRESULT AddApplicationException( INetFwRule* pNetFwRule = NULL; // convert to BSTRs to make COM happy - bstrFile = ::SysAllocString(wzFile); + bstrFile = ::SysAllocString(pAttrs->pwzProgram); ExitOnNull(bstrFile, hr, E_OUTOFMEMORY, "failed SysAllocString for path"); - bstrName = ::SysAllocString(wzName); + bstrName = ::SysAllocString(pAttrs->pwzName); ExitOnNull(bstrName, hr, E_OUTOFMEMORY, "failed SysAllocString for name"); // get the collection of firewall rules @@ -471,21 +738,13 @@ static HRESULT AddApplicationException( hr = pNetFwRules->Item(bstrName, &pNetFwRule); if (HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == hr) { - hr = CreateFwRuleObject(bstrName, iProfile, wzRemoteAddresses, wzPort, iProtocol, wzDescription, iDirection, &pNetFwRule); + hr = CreateFwRuleObject(bstrName, pAttrs, &pNetFwRule); ExitOnFailure(hr, "failed to create FwRule object"); - // set edge traversal to true - hr = pNetFwRule->put_EdgeTraversal(VARIANT_TRUE); - ExitOnFailure(hr, "failed to set application exception edgetraversal property"); - // set path hr = pNetFwRule->put_ApplicationName(bstrFile); ExitOnFailure(hr, "failed to set application name"); - - // enable it - hr = pNetFwRule->put_Enabled(VARIANT_TRUE); - ExitOnFailure(hr, "failed to to enable application exception"); - + // add it to the list of authorized apps hr = pNetFwRules->Add(pNetFwRule); ExitOnFailure(hr, "failed to add app to the authorized apps list"); @@ -494,9 +753,16 @@ static HRESULT AddApplicationException( { // we found an existing app exception (if we succeeded, that is) ExitOnFailure(hr, "failed trying to find existing app"); - - // enable it (just in case it was disabled) - pNetFwRule->put_Enabled(VARIANT_TRUE); + + hr = UpdateFwRuleObject(pNetFwRule, fOverwriteOnChange, pAttrs); + ExitOnFailure(hr, "failed to update an existing app"); + + if (fOverwriteOnChange) + { + // set path + hr = pNetFwRule->put_ApplicationName(bstrFile); + ExitOnFailure(hr, "failed to update application name"); + } } LExit: @@ -513,10 +779,9 @@ static HRESULT AddApplicationException( ********************************************************************/ static HRESULT AddApplicationExceptionOnCurrentProfile( - __in LPCWSTR wzFile, - __in LPCWSTR wzName, - __in_opt LPCWSTR wzRemoteAddresses, - __in BOOL fIgnoreFailures + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, + __in BOOL fIgnoreFailures, + __in BOOL fOverwriteOnChange ) { HRESULT hr = S_OK; @@ -527,12 +792,15 @@ static HRESULT AddApplicationExceptionOnCurrentProfile( INetFwAuthorizedApplications* pfwApps = NULL; INetFwAuthorizedApplication* pfwApp = NULL; + BOOL fEnabled = feaEnabled == (pAttrs->iAttributes & feaEnabled); + BOOL fEnabledCustomized = feaEnabledCustomized == (pAttrs->iAttributes & feaEnabledCustomized); + // convert to BSTRs to make COM happy - bstrFile = ::SysAllocString(wzFile); + bstrFile = ::SysAllocString(pAttrs->pwzProgram); ExitOnNull(bstrFile, hr, E_OUTOFMEMORY, "failed SysAllocString for path"); - bstrName = ::SysAllocString(wzName); + bstrName = ::SysAllocString(pAttrs->pwzName); ExitOnNull(bstrName, hr, E_OUTOFMEMORY, "failed SysAllocString for name"); - bstrRemoteAddresses = ::SysAllocString(wzRemoteAddresses); + bstrRemoteAddresses = ::SysAllocString(pAttrs->pwzRemoteAddresses); ExitOnNull(bstrRemoteAddresses, hr, E_OUTOFMEMORY, "failed SysAllocString for remote addresses"); // get the firewall profile, which is our entry point for adding exceptions @@ -552,7 +820,7 @@ static HRESULT AddApplicationExceptionOnCurrentProfile( if (HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == hr) { // not found, so we get to add it - hr = ::CoCreateInstance(__uuidof(NetFwAuthorizedApplication), NULL, CLSCTX_INPROC_SERVER, __uuidof(INetFwAuthorizedApplication), reinterpret_cast(&pfwApp)); + hr = ::CoCreateInstance(__uuidof(NetFwAuthorizedApplication), NULL, CLSCTX_INPROC_SERVER, __uuidof(INetFwAuthorizedApplication), reinterpret_cast(&pfwApp)); ExitOnFailure(hr, "failed to create authorized app"); // set the display name @@ -570,6 +838,10 @@ static HRESULT AddApplicationExceptionOnCurrentProfile( ExitOnFailure(hr, "failed to set authorized app remote addresses"); } + VARIANT_BOOL fEnable = fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE; + hr = pfwApp->put_Enabled(fEnable); + ExitOnFailure(hr, "failed to %s authorized app", VARIANT_TRUE == fEnable ? "enable" : "disable"); + // add it to the list of authorized apps hr = pfwApps->Add(pfwApp); ExitOnFailure(hr, "failed to add app to the authorized apps list"); @@ -579,8 +851,29 @@ static HRESULT AddApplicationExceptionOnCurrentProfile( // we found an existing app exception (if we succeeded, that is) ExitOnFailure(hr, "failed trying to find existing app"); - // enable it (just in case it was disabled) - pfwApp->put_Enabled(VARIANT_TRUE); + // change it or enable it (just in case it was disabled, for backwards compatibilty) + VARIANT_BOOL fEnable = fOverwriteOnChange ? (fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE) : VARIANT_TRUE; + hr = pfwApp->put_Enabled(fEnable); + ExitOnFailure(hr, "failed to re%s authorized app", VARIANT_TRUE == fEnable ? "enable" : "disable"); + + if (fOverwriteOnChange) + { + // set path + hr = pfwApp->put_ProcessImageFileName(bstrFile); + ExitOnFailure(hr, "failed to update authorized app path"); + + // update the allowed remote addresses + if (bstrRemoteAddresses && *bstrRemoteAddresses) + { + hr = pfwApp->put_RemoteAddresses(bstrRemoteAddresses); + ExitOnFailure(hr, "failed to update authorized app remote addresses"); + } + else + { + hr = pfwApp->put_RemoteAddresses(NULL); + ExitOnFailure(hr, "failed to remove authorized app remote addresses"); + } + } } LExit: @@ -599,15 +892,10 @@ static HRESULT AddApplicationExceptionOnCurrentProfile( ********************************************************************/ static HRESULT AddPortException( - __in LPCWSTR wzName, - __in int iProfile, - __in_opt LPCWSTR wzRemoteAddresses, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __in BOOL fIgnoreFailures, - __in LPCWSTR wzPort, - __in int iProtocol, - __in LPCWSTR wzDescription, - __in int iDirection -) + __in BOOL fOverwriteOnChange + ) { HRESULT hr = S_OK; BSTR bstrName = NULL; @@ -615,7 +903,7 @@ static HRESULT AddPortException( INetFwRule* pNetFwRule = NULL; // convert to BSTRs to make COM happy - bstrName = ::SysAllocString(wzName); + bstrName = ::SysAllocString(pAttrs->pwzName); ExitOnNull(bstrName, hr, E_OUTOFMEMORY, "failed SysAllocString for name"); // get the collection of firewall rules @@ -630,13 +918,9 @@ static HRESULT AddPortException( hr = pNetFwRules->Item(bstrName, &pNetFwRule); if (HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == hr) { - hr = CreateFwRuleObject(bstrName, iProfile, wzRemoteAddresses, wzPort, iProtocol, wzDescription, iDirection, &pNetFwRule); + hr = CreateFwRuleObject(bstrName, pAttrs, &pNetFwRule); ExitOnFailure(hr, "failed to create FwRule object"); - // enable it - hr = pNetFwRule->put_Enabled(VARIANT_TRUE); - ExitOnFailure(hr, "failed to to enable port exception"); - // add it to the list of authorized ports hr = pNetFwRules->Add(pNetFwRule); ExitOnFailure(hr, "failed to add app to the authorized ports list"); @@ -646,8 +930,9 @@ static HRESULT AddPortException( // we found an existing port exception (if we succeeded, that is) ExitOnFailure(hr, "failed trying to find existing port rule"); - // enable it (just in case it was disabled) - pNetFwRule->put_Enabled(VARIANT_TRUE); + // overwrite attributes of the existing port exception + hr = UpdateFwRuleObject(pNetFwRule, fOverwriteOnChange, pAttrs); + ExitOnFailure(hr, "failed to update an existing port rule"); } LExit: @@ -663,46 +948,29 @@ static HRESULT AddPortException( ********************************************************************/ static HRESULT AddPortExceptionOnCurrentProfile( - __in LPCWSTR wzName, - __in_opt LPCWSTR wzRemoteAddresses, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __in BOOL fIgnoreFailures, - __in int iPort, - __in int iProtocol + __in BOOL fOverwriteOnChange ) { HRESULT hr = S_OK; BSTR bstrName = NULL; BSTR bstrRemoteAddresses = NULL; + LONG lPortNumber = wcstol(pAttrs->pwzPort, NULL, 10); INetFwProfile* pfwProfile = NULL; INetFwOpenPorts* pfwPorts = NULL; INetFwOpenPort* pfwPort = NULL; + BOOL fEnabled = feaEnabled == (pAttrs->iAttributes & feaEnabled); + BOOL fEnabledCustomized = feaEnabledCustomized == (pAttrs->iAttributes & feaEnabledCustomized); + // convert to BSTRs to make COM happy - bstrName = ::SysAllocString(wzName); + bstrName = ::SysAllocString(pAttrs->pwzName); ExitOnNull(bstrName, hr, E_OUTOFMEMORY, "failed SysAllocString for name"); - bstrRemoteAddresses = ::SysAllocString(wzRemoteAddresses); + bstrRemoteAddresses = ::SysAllocString(pAttrs->pwzRemoteAddresses); ExitOnNull(bstrRemoteAddresses, hr, E_OUTOFMEMORY, "failed SysAllocString for remote addresses"); - // create and initialize a new open port object - hr = ::CoCreateInstance(__uuidof(NetFwOpenPort), NULL, CLSCTX_INPROC_SERVER, __uuidof(INetFwOpenPort), reinterpret_cast(&pfwPort)); - ExitOnFailure(hr, "failed to create new open port"); - - hr = pfwPort->put_Port(iPort); - ExitOnFailure(hr, "failed to set exception port"); - - hr = pfwPort->put_Protocol(static_cast(iProtocol)); - ExitOnFailure(hr, "failed to set exception protocol"); - - if (bstrRemoteAddresses && *bstrRemoteAddresses) - { - hr = pfwPort->put_RemoteAddresses(bstrRemoteAddresses); - ExitOnFailure(hr, "failed to set exception remote addresses '%ls'", bstrRemoteAddresses); - } - - hr = pfwPort->put_Name(bstrName); - ExitOnFailure(hr, "failed to set exception name"); - - // get the firewall profile, its current list of open ports, and add ours + // get the firewall profile, which is our entry point for adding exceptions hr = GetCurrentFirewallProfile(fIgnoreFailures, &pfwProfile); ExitOnFailure(hr, "failed to get firewall profile"); if (S_FALSE == hr) // user or package author chose to ignore missing firewall @@ -710,11 +978,88 @@ static HRESULT AddPortExceptionOnCurrentProfile( ExitFunction(); } + // first, let's see if the port is already on the exception list hr = pfwProfile->get_GloballyOpenPorts(&pfwPorts); ExitOnFailure(hr, "failed to get open ports"); - hr = pfwPorts->Add(pfwPort); - ExitOnFailure(hr, "failed to add exception to global list"); + // try to find it (i.e., support reinstall) + hr = pfwPorts->Item(lPortNumber, static_cast (pAttrs->iProtocol), &pfwPort); + if (HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == hr) + { + // not found, so we get to add it + hr = ::CoCreateInstance(__uuidof(NetFwOpenPort), NULL, CLSCTX_INPROC_SERVER, __uuidof(INetFwOpenPort), reinterpret_cast(&pfwPort)); + ExitOnFailure(hr, "failed to create new open port"); + + hr = pfwPort->put_Name(bstrName); + ExitOnFailure(hr, "failed to set port exception name"); + + hr = pfwPort->put_Protocol(static_cast (pAttrs->iProtocol)); + ExitOnFailure(hr, "failed to set port exception protocol"); + + hr = pfwPort->put_Port(lPortNumber); + ExitOnFailure(hr, "failed to set port exception port"); + + if (bstrRemoteAddresses && *bstrRemoteAddresses) + { + hr = pfwPort->put_RemoteAddresses(bstrRemoteAddresses); + ExitOnFailure(hr, "failed to set port exception remote addresses '%ls'", bstrRemoteAddresses); + } + + VARIANT_BOOL fEnable = fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE; + hr = pfwPort->put_Enabled(fEnable); + ExitOnFailure(hr, "failed to %s port exception", VARIANT_TRUE == fEnable ? "enable" : "disable"); + + // add it to the list of authorized ports + hr = pfwPorts->Add(pfwPort); + ExitOnFailure(hr, "failed to add port exception to global list"); + } + else + { + // we found an existing port exception (if we succeeded, that is) + ExitOnFailure(hr, "failed trying to find existing port"); + + // change it or enable it (just in case it was disabled, for backwards compatibilty) + VARIANT_BOOL fEnable = fOverwriteOnChange ? (fEnabledCustomized && FALSE == fEnabled ? VARIANT_FALSE : VARIANT_TRUE) : VARIANT_TRUE; + hr = pfwPort->put_Enabled(fEnable); + ExitOnFailure(hr, "failed to re%s existing port", VARIANT_TRUE == fEnable ? "enable" : "disable"); + + if (fOverwriteOnChange) + { + hr = pfwPort->put_Name(bstrName); + ExitOnFailure(hr, "failed to update port exception name"); + + // If you are editing a TCP port rule and converting it into an ICMP rule, + // first delete the port, change protocol from TCP to ICMP, and then add the rule. + hr = pfwPort->put_Port(0); + ExitOnFailure(hr, "failed to update port exception port to NULL"); + + // The Protocol property must be set before the LocalPorts or RemotePorts properties or an error will be returned. + if (MSI_NULL_INTEGER != pAttrs->iProtocol) + { + hr = pfwPort->put_Protocol(static_cast (pAttrs->iProtocol)); + ExitOnFailure(hr, "failed to update port exception protocol"); + } + else + { + hr = pfwPort->put_Protocol(NET_FW_IP_PROTOCOL_ANY); + ExitOnFailure(hr, "failed to reset port exception protocol to ANY"); + } + + hr = pfwPort->put_Port(lPortNumber); + ExitOnFailure(hr, "failed to update port exception port"); + + if (bstrRemoteAddresses && *bstrRemoteAddresses) + { + hr = pfwPort->put_RemoteAddresses(bstrRemoteAddresses); + ExitOnFailure(hr, "failed to update port exception remote addresses '%ls'", bstrRemoteAddresses); + } + else + { + hr = pfwPort->put_RemoteAddresses(NULL); + ExitOnFailure(hr, "failed to remove port exception remote addresses"); + } + } + } LExit: ReleaseBSTR(bstrRemoteAddresses); @@ -825,7 +1170,7 @@ static HRESULT RemovePortExceptionFromCurrentProfile( hr = pfwProfile->get_GloballyOpenPorts(&pfwPorts); ExitOnFailure(hr, "failed to get open ports"); - hr = pfwPorts->Remove(iPort, static_cast(iProtocol)); + hr = pfwPorts->Remove(iPort, static_cast (iProtocol)); ExitOnFailure(hr, "failed to remove open port %d, protocol %d", iPort, iProtocol); LExit: @@ -834,36 +1179,36 @@ static HRESULT RemovePortExceptionFromCurrentProfile( static HRESULT AddApplicationException( __in BOOL fSupportProfiles, - __in LPCWSTR wzFile, - __in LPCWSTR wzName, - __in int iProfile, - __in_opt LPCWSTR wzRemoteAddresses, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __in BOOL fIgnoreFailures, - __in LPCWSTR wzPort, - __in int iProtocol, - __in LPCWSTR wzDescription, - __in int iDirection -) + __in BOOL fOverwriteOnChange + ) { HRESULT hr = S_OK; if (fSupportProfiles) { - hr = AddApplicationException(wzFile, wzName, iProfile, wzRemoteAddresses, fIgnoreFailures, wzPort, iProtocol, wzDescription, iDirection); + hr = AddApplicationException(pAttrs, fIgnoreFailures, fOverwriteOnChange); } else { - if (0 != *wzPort || MSI_NULL_INTEGER != iProtocol) + if ((pAttrs->pwzPort && NULL != *pAttrs->pwzPort) || MSI_NULL_INTEGER != pAttrs->iProtocol ) { // NOTE: This is treated as an error rather than either creating a rule based on just the application (no port), or // just the port because it is unclear what is the proper fall back. For example, suppose that you have code that // runs in dllhost.exe. Clearly falling back to opening all of dllhost is wrong. Because the firewall is a security // feature, it seems better to require the MSI author to indicate the behavior that they want. - WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot add firewall rule '%ls', which defines both an application and a port or protocol. Such a rule requires Microsoft Windows Vista or later.", wzName); + WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot add firewall rule '%ls', which defines both an application and a port or protocol. Such a rule requires Microsoft Windows Vista or later.", pAttrs->pwzName); return fIgnoreFailures ? S_OK : E_NOTIMPL; } - hr = AddApplicationExceptionOnCurrentProfile(wzFile, wzName, wzRemoteAddresses, fIgnoreFailures); + if ((pAttrs->pwzService && NULL != *pAttrs->pwzService) || (pAttrs->pwzInterfaceTypes && NULL != *pAttrs->pwzInterfaceTypes) || (pAttrs->pwzRemotePort && NULL != *pAttrs->pwzRemotePort) || MSI_NULL_INTEGER != pAttrs->iEdgeTraversal) + { + WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot add firewall rule '%ls', which defines Service, InterfaceTypes, RemotePort or EdgeTraversal. Such a rule requires Microsoft Windows Vista or later.", pAttrs->pwzName); + return fIgnoreFailures ? S_OK : E_NOTIMPL; + } + + hr = AddApplicationExceptionOnCurrentProfile(pAttrs, fIgnoreFailures, fOverwriteOnChange); } return hr; @@ -871,25 +1216,36 @@ static HRESULT AddApplicationException( static HRESULT AddPortException( __in BOOL fSupportProfiles, - __in LPCWSTR wzName, - __in int iProfile, - __in_opt LPCWSTR wzRemoteAddresses, + __in FIREWALL_EXCEPTION_ATTRIBUTES const* pAttrs, __in BOOL fIgnoreFailures, - __in LPCWSTR wzPort, - __in int iProtocol, - __in LPCWSTR wzDescription, - __in int iDirection -) + __in BOOL fOverwriteOnChange + ) { HRESULT hr = S_OK; if (fSupportProfiles) { - hr = AddPortException(wzName, iProfile, wzRemoteAddresses, fIgnoreFailures, wzPort, iProtocol, wzDescription, iDirection); + hr = AddPortException(pAttrs, fIgnoreFailures, fOverwriteOnChange); } else { - hr = AddPortExceptionOnCurrentProfile(wzName, wzRemoteAddresses, fIgnoreFailures, wcstol(wzPort, NULL, 10), iProtocol); + if (pAttrs->pwzProgram && NULL != *pAttrs->pwzProgram) + { + // NOTE: This is treated as an error rather than either creating a rule based on just the port, or + // just the application because it is unclear what is the proper fall back. For example, suppose that you have code that + // runs in dllhost.exe. Clearly falling back to opening all of dllhost is wrong. Because the firewall is a security + // feature, it seems better to require the MSI author to indicate the behavior that they want. + WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot add firewall rule '%ls', which defines both an application and a port or protocol. Such a rule requires Microsoft Windows Vista or later.", pAttrs->pwzName); + return fIgnoreFailures ? S_OK : E_NOTIMPL; + } + + if ((pAttrs->pwzService && NULL != *pAttrs->pwzService) || (pAttrs->pwzInterfaceTypes && NULL != *pAttrs->pwzInterfaceTypes) || (pAttrs->pwzRemotePort && NULL != *pAttrs->pwzRemotePort) || MSI_NULL_INTEGER != pAttrs->iEdgeTraversal) + { + WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot add firewall rule '%ls', which defines Service, InterfaceTypes, RemotePort or EdgeTraversal. Such a rule requires Microsoft Windows Vista or later.", pAttrs->pwzName); + return fIgnoreFailures ? S_OK : E_NOTIMPL; + } + + hr = AddPortExceptionOnCurrentProfile(pAttrs, fIgnoreFailures, fOverwriteOnChange); } return hr; @@ -898,7 +1254,7 @@ static HRESULT AddPortException( static HRESULT RemoveApplicationException( __in BOOL fSupportProfiles, __in LPCWSTR wzName, - __in LPCWSTR wzFile, + __in LPCWSTR wzFile, __in BOOL fIgnoreFailures, __in LPCWSTR wzPort, __in int iProtocol @@ -912,10 +1268,10 @@ static HRESULT RemoveApplicationException( } else { - if (0 != *wzPort || MSI_NULL_INTEGER != iProtocol) + if ((wzPort && NULL != *wzPort) || MSI_NULL_INTEGER != iProtocol) { WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot remove firewall rule '%ls', which defines both an application and a port or protocol. Such a rule requires Microsoft Windows Vista or later.", wzName); - return S_OK; + return fIgnoreFailures ? S_OK : E_NOTIMPL; } hr = RemoveApplicationExceptionFromCurrentProfile(wzFile, fIgnoreFailures); @@ -927,6 +1283,7 @@ static HRESULT RemoveApplicationException( static HRESULT RemovePortException( __in BOOL fSupportProfiles, __in LPCWSTR wzName, + __in_opt LPCWSTR wzFile, __in LPCWSTR wzPort, __in int iProtocol, __in BOOL fIgnoreFailures @@ -940,6 +1297,12 @@ static HRESULT RemovePortException( } else { + if (wzFile && NULL != *wzFile) + { + WcaLog(LOGMSG_STANDARD, "FirewallExtension: Cannot remove firewall rule '%ls', which defines both an application and a port or protocol. Such a rule requires Microsoft Windows Vista or later.", wzName); + return fIgnoreFailures ? S_OK : E_NOTIMPL; + } + hr = RemovePortExceptionFromCurrentProfile(wcstol(wzPort, NULL, 10), iProtocol, fIgnoreFailures); } @@ -960,16 +1323,9 @@ extern "C" UINT __stdcall ExecFirewallExceptions( LPWSTR pwz = NULL; LPWSTR pwzCustomActionData = NULL; int iTodo = WCA_TODO_UNKNOWN; - LPWSTR pwzName = NULL; - LPWSTR pwzRemoteAddresses = NULL; - int iAttributes = 0; int iTarget = fetUnknown; - LPWSTR pwzFile = NULL; - LPWSTR pwzPort = NULL; - LPWSTR pwzDescription = NULL; - int iProtocol = 0; - int iProfile = 0; - int iDirection = 0; + + FIREWALL_EXCEPTION_ATTRIBUTES attrs = { 0 }; // initialize hr = WcaInitialize(hInstall, "ExecFirewallExceptions"); @@ -983,12 +1339,25 @@ extern "C" UINT __stdcall ExecFirewallExceptions( ExitOnFailure(hr, "failed to initialize COM"); // Find out if we support profiles (only on Vista or later) + // can be disabled by #define DISABLE_PROFILES fSupportProfiles = FSupportProfiles(); // loop through all the passed in data pwz = pwzCustomActionData; while (pwz && *pwz) { + ReleaseNullStr(attrs.pwzName); + ReleaseNullStr(attrs.pwzRemoteAddresses); + ReleaseNullStr(attrs.pwzPort); + ReleaseNullStr(attrs.pwzProgram); + ReleaseNullStr(attrs.pwzDescription); + ReleaseNullStr(attrs.pwzService); + ReleaseNullStr(attrs.pwzInterfaceTypes); + ReleaseNullStr(attrs.pwzRemotePort); + + attrs.iDirection = MSI_NULL_INTEGER; + attrs.iEdgeTraversal = MSI_NULL_INTEGER; + // extract the custom action data and if rolling back, swap INSTALL and UNINSTALL hr = WcaReadIntegerFromCaData(&pwz, &iTodo); ExitOnFailure(hr, "failed to read todo from custom action data"); @@ -1004,36 +1373,45 @@ extern "C" UINT __stdcall ExecFirewallExceptions( } } - hr = WcaReadStringFromCaData(&pwz, &pwzName); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzName); ExitOnFailure(hr, "failed to read name from custom action data"); - hr = WcaReadIntegerFromCaData(&pwz, &iProfile); + hr = WcaReadIntegerFromCaData(&pwz, &attrs.iProfile); ExitOnFailure(hr, "failed to read profile from custom action data"); - hr = WcaReadStringFromCaData(&pwz, &pwzRemoteAddresses); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzRemoteAddresses); ExitOnFailure(hr, "failed to read remote addresses from custom action data"); - hr = WcaReadIntegerFromCaData(&pwz, &iAttributes); + hr = WcaReadIntegerFromCaData(&pwz, &attrs.iAttributes); ExitOnFailure(hr, "failed to read attributes from custom action data"); - BOOL fIgnoreFailures = feaIgnoreFailures == (iAttributes & feaIgnoreFailures); + BOOL fIgnoreFailures = feaIgnoreFailures == (attrs.iAttributes & feaIgnoreFailures); + BOOL fOverwriteOnChange = feaOverwriteOnChange == (attrs.iAttributes & feaOverwriteOnChange); hr = WcaReadIntegerFromCaData(&pwz, &iTarget); ExitOnFailure(hr, "failed to read target from custom action data"); if (iTarget == fetApplication) { - hr = WcaReadStringFromCaData(&pwz, &pwzFile); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzProgram); ExitOnFailure(hr, "failed to read file path from custom action data"); } - hr = WcaReadStringFromCaData(&pwz, &pwzPort); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzPort); ExitOnFailure(hr, "failed to read port from custom action data"); - hr = WcaReadIntegerFromCaData(&pwz, &iProtocol); + hr = WcaReadIntegerFromCaData(&pwz, &attrs.iProtocol); ExitOnFailure(hr, "failed to read protocol from custom action data"); - hr = WcaReadStringFromCaData(&pwz, &pwzDescription); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzDescription); ExitOnFailure(hr, "failed to read protocol from custom action data"); - hr = WcaReadIntegerFromCaData(&pwz, &iDirection); + hr = WcaReadIntegerFromCaData(&pwz, &attrs.iDirection); ExitOnFailure(hr, "failed to read direction from custom action data"); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzService); + ExitOnFailure(hr, "failed to read service name from custom action data"); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzInterfaceTypes); + ExitOnFailure(hr, "failed to read interface types from custom action data"); + hr = WcaReadIntegerFromCaData(&pwz, &attrs.iEdgeTraversal); + ExitOnFailure(hr, "failed to read edge traversal from custom action data"); + hr = WcaReadStringFromCaData(&pwz, &attrs.pwzRemotePort); + ExitOnFailure(hr, "failed to read remote port from custom action data"); switch (iTarget) { @@ -1042,15 +1420,15 @@ extern "C" UINT __stdcall ExecFirewallExceptions( { case WCA_TODO_INSTALL: case WCA_TODO_REINSTALL: - WcaLog(LOGMSG_STANDARD, "Installing firewall exception2 %ls on port %ls, protocol %d", pwzName, pwzPort, iProtocol); - hr = AddPortException(fSupportProfiles, pwzName, iProfile, pwzRemoteAddresses, fIgnoreFailures, pwzPort, iProtocol, pwzDescription, iDirection); - ExitOnFailure(hr, "failed to add/update port exception for name '%ls' on port %ls, protocol %d", pwzName, pwzPort, iProtocol); + WcaLog(LOGMSG_STANDARD, "Installing firewall exception %ls on port %ls, protocol %d", attrs.pwzName, attrs.pwzPort, attrs.iProtocol); + hr = AddPortException(fSupportProfiles , &attrs, fIgnoreFailures, fOverwriteOnChange); + ExitOnFailure(hr, "failed to add/update port exception for name '%ls' on port %ls, protocol %d, service '%ls'", attrs.pwzName, attrs.pwzPort, attrs.iProtocol, attrs.pwzService); break; case WCA_TODO_UNINSTALL: - WcaLog(LOGMSG_STANDARD, "Uninstalling firewall exception2 %ls on port %ls, protocol %d", pwzName, pwzPort, iProtocol); - hr = RemovePortException(fSupportProfiles, pwzName, pwzPort, iProtocol, fIgnoreFailures); - ExitOnFailure(hr, "failed to remove port exception for name '%ls' on port %ls, protocol %d", pwzName, pwzPort, iProtocol); + WcaLog(LOGMSG_STANDARD, "Uninstalling firewall exception %ls on port %ls, protocol %d", attrs.pwzName, attrs.pwzPort, attrs.iProtocol); + hr = RemovePortException(fSupportProfiles, attrs.pwzName, attrs.pwzProgram, attrs.pwzPort, attrs.iProtocol, fIgnoreFailures); + ExitOnFailure(hr, "failed to remove port exception for name '%ls' on port %ls, protocol %d", attrs.pwzName, attrs.pwzPort, attrs.iProtocol); break; } break; @@ -1060,15 +1438,15 @@ extern "C" UINT __stdcall ExecFirewallExceptions( { case WCA_TODO_INSTALL: case WCA_TODO_REINSTALL: - WcaLog(LOGMSG_STANDARD, "Installing firewall exception2 %ls (%ls)", pwzName, pwzFile); - hr = AddApplicationException(fSupportProfiles, pwzFile, pwzName, iProfile, pwzRemoteAddresses, fIgnoreFailures, pwzPort, iProtocol, pwzDescription, iDirection); - ExitOnFailure(hr, "failed to add/update application exception for name '%ls', file '%ls'", pwzName, pwzFile); + WcaLog(LOGMSG_STANDARD, "Installing firewall exception %ls (%ls)", attrs.pwzName, attrs.pwzProgram); + hr = AddApplicationException(fSupportProfiles, &attrs, fIgnoreFailures, fOverwriteOnChange); + ExitOnFailure(hr, "failed to add/update application exception for name '%ls', file '%ls', service '%ls'", attrs.pwzName, attrs.pwzProgram, attrs.pwzService); break; case WCA_TODO_UNINSTALL: - WcaLog(LOGMSG_STANDARD, "Uninstalling firewall exception2 %ls (%ls)", pwzName, pwzFile); - hr = RemoveApplicationException(fSupportProfiles, pwzName, pwzFile, fIgnoreFailures, pwzPort, iProtocol); - ExitOnFailure(hr, "failed to remove application exception for name '%ls', file '%ls'", pwzName, pwzFile); + WcaLog(LOGMSG_STANDARD, "Uninstalling firewall exception %ls (%ls)", attrs.pwzName, attrs.pwzProgram); + hr = RemoveApplicationException(fSupportProfiles, attrs.pwzName, attrs.pwzProgram, fIgnoreFailures, attrs.pwzPort, attrs.iProtocol); + ExitOnFailure(hr, "failed to remove application exception for name '%ls', file '%ls'", attrs.pwzName, attrs.pwzProgram); break; } break; @@ -1076,12 +1454,14 @@ extern "C" UINT __stdcall ExecFirewallExceptions( } LExit: - ReleaseStr(pwzCustomActionData); - ReleaseStr(pwzName); - ReleaseStr(pwzRemoteAddresses); - ReleaseStr(pwzFile); - ReleaseStr(pwzPort); - ReleaseStr(pwzDescription); + ReleaseNullStr(pwzCustomActionData); + ReleaseNullStr(attrs.pwzName); + ReleaseNullStr(attrs.pwzRemoteAddresses); + ReleaseNullStr(attrs.pwzProgram); + ReleaseNullStr(attrs.pwzPort); + ReleaseNullStr(attrs.pwzDescription); + ReleaseNullStr(attrs.pwzService); + ReleaseNullStr(attrs.pwzInterfaceTypes); ::CoUninitialize(); return WcaFinalize(FAILED(hr) ? ERROR_INSTALL_FAILURE : ERROR_SUCCESS); diff --git a/src/ext/Firewall/test/WixToolsetTest.Firewall/FirewallExtensionFixture.cs b/src/ext/Firewall/test/WixToolsetTest.Firewall/FirewallExtensionFixture.cs index b89afaf7a..ef29ade0b 100644 --- a/src/ext/Firewall/test/WixToolsetTest.Firewall/FirewallExtensionFixture.cs +++ b/src/ext/Firewall/test/WixToolsetTest.Firewall/FirewallExtensionFixture.cs @@ -7,6 +7,8 @@ namespace WixToolsetTest.Firewall using WixInternal.Core.TestPackage; using WixToolset.Firewall; using Xunit; + using System.IO; + using System.Xml.Linq; public class FirewallExtensionFixture { @@ -25,8 +27,10 @@ public void CanBuildUsingFirewall() "CustomAction:Wix4RollbackFirewallExceptionsUninstall_X86\t3329\tWix4FWCA_X86\tExecFirewallExceptions\t", "CustomAction:Wix4SchedFirewallExceptionsInstall_X86\t1\tWix4FWCA_X86\tSchedFirewallExceptionsInstall\t", "CustomAction:Wix4SchedFirewallExceptionsUninstall_X86\t1\tWix4FWCA_X86\tSchedFirewallExceptionsUninstall\t", - "Wix4FirewallException:ExampleFirewall\tExampleApp\t*\t42\t6\t[#filNdJBJmq3UCUIwmXS8x21aAsvqzk]\t0\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tAn app-based firewall exception\t1", - "Wix4FirewallException:fex70IVsYNnbwiHQrEepmdTPKH8XYs\tExamplePort\tLocalSubnet\t42\t6\t\t0\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception\t2", + "Wix4FirewallException:ExampleFirewall\tExampleApp\t*\t42\t6\t[#filNdJBJmq3UCUIwmXS8x21aAsvqzk]\t4\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tAn app-based firewall exception\t1\t\tAll\t\t", + "Wix4FirewallException:fex70IVsYNnbwiHQrEepmdTPKH8XYs\tExamplePort\tLocalSubnet\t42\t6\t\t5\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception\t2\tftpsrv\tLan\t\t", + "Wix4FirewallException:fexiVb_lnYx2.K.OSyNlgawFJVTqEw\tdefertouser\t\t\t\tfw.exe\t4\t\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tDefer to user edge traversal\t1\t\t\t3\t", + "Wix4FirewallException:ServiceInstall.nested\tExamplePort\tLocalSubnet\t3546-7890\t6\t\t5\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception for a windows service\t1\tsvc1\tWireless,Lan,RemoteAccess\t1\t", }, results); } @@ -45,11 +49,37 @@ public void CanBuildUsingFirewallARM64() "CustomAction:Wix4RollbackFirewallExceptionsUninstall_A64\t3329\tWix4FWCA_A64\tExecFirewallExceptions\t", "CustomAction:Wix4SchedFirewallExceptionsInstall_A64\t1\tWix4FWCA_A64\tSchedFirewallExceptionsInstall\t", "CustomAction:Wix4SchedFirewallExceptionsUninstall_A64\t1\tWix4FWCA_A64\tSchedFirewallExceptionsUninstall\t", - "Wix4FirewallException:ExampleFirewall\tExampleApp\t*\t42\t6\t[#filNdJBJmq3UCUIwmXS8x21aAsvqzk]\t0\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tAn app-based firewall exception\t1", - "Wix4FirewallException:fex70IVsYNnbwiHQrEepmdTPKH8XYs\tExamplePort\tLocalSubnet\t42\t6\t\t0\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception\t2", + "Wix4FirewallException:ExampleFirewall\tExampleApp\t*\t42\t6\t[#filNdJBJmq3UCUIwmXS8x21aAsvqzk]\t4\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tAn app-based firewall exception\t1\t\tAll\t\t", + "Wix4FirewallException:fex70IVsYNnbwiHQrEepmdTPKH8XYs\tExamplePort\tLocalSubnet\t42\t6\t\t5\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception\t2\tftpsrv\tLan\t\t", + "Wix4FirewallException:fexiVb_lnYx2.K.OSyNlgawFJVTqEw\tdefertouser\t\t\t\tfw.exe\t4\t\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tDefer to user edge traversal\t1\t\t\t3\t", + "Wix4FirewallException:ServiceInstall.nested\tExamplePort\tLocalSubnet\t3546-7890\t6\t\t5\t2147483647\tfilNdJBJmq3UCUIwmXS8x21aAsvqzk\tA port-based firewall exception for a windows service\t1\tsvc1\tWireless,Lan,RemoteAccess\t1\t", }, results); } + [Fact] + public void CanRoundtripFirewallExceptions() + { + var folder = TestData.Get(@"TestData", "UsingFirewall"); + var build = new Builder(folder, typeof(FirewallExtensionFactory), new[] { folder }); + var output = Path.Combine(folder, "FirewallExceptionDecompile.xml"); + + build.BuildAndDecompileAndBuild(Build, Decompile, output); + + var doc = XDocument.Load(output); + var firewallElementNames = doc.Descendants().Where(e => e.Name.Namespace == "http://wixtoolset.org/schemas/v4/wxs/firewall") + .Select(e => e.Name.LocalName) + .ToArray(); + + WixAssert.CompareLineByLine(new[] + { + "FirewallException", + "FirewallException", + "FirewallException", + "FirewallException", + }, firewallElementNames); + } + + private static void Build(string[] args) { var result = WixRunner.Execute(args); @@ -65,5 +95,10 @@ private static void BuildARM64(string[] args) var result = WixRunner.Execute(newArgs.ToArray()); result.AssertSuccess(); } + private static void Decompile(string[] args) + { + var result = WixRunner.Execute(args); + result.AssertSuccess(); + } } } diff --git a/src/ext/Firewall/test/WixToolsetTest.Firewall/TestData/UsingFirewall/PackageComponents.wxs b/src/ext/Firewall/test/WixToolsetTest.Firewall/TestData/UsingFirewall/PackageComponents.wxs index c712d8959..1e5b0ab96 100644 --- a/src/ext/Firewall/test/WixToolsetTest.Firewall/TestData/UsingFirewall/PackageComponents.wxs +++ b/src/ext/Firewall/test/WixToolsetTest.Firewall/TestData/UsingFirewall/PackageComponents.wxs @@ -1,17 +1,24 @@ + xmlns:fw="http://wixtoolset.org/schemas/v4/wxs/firewall" + xmlns:util="http://wixtoolset.org/schemas/v4/wxs/util"> - + - + + + + + + + diff --git a/src/ext/Firewall/wixext/FirewallCompiler.cs b/src/ext/Firewall/wixext/FirewallCompiler.cs index cbe82d37b..0d00f9032 100644 --- a/src/ext/Firewall/wixext/FirewallCompiler.cs +++ b/src/ext/Firewall/wixext/FirewallCompiler.cs @@ -15,7 +15,7 @@ namespace WixToolset.Firewall /// public sealed class FirewallCompiler : BaseCompilerExtension { - public override XNamespace Namespace => "http://wixtoolset.org/schemas/v4/wxs/firewall"; + public override XNamespace Namespace => FirewallConstants.Namespace; /// /// Processes an element for the Compiler. @@ -35,7 +35,7 @@ public override void ParseElement(Intermediate intermediate, IntermediateSection switch (element.Name.LocalName) { case "FirewallException": - this.ParseFirewallExceptionElement(intermediate, section, element, fileComponentId, fileId); + this.ParseFirewallExceptionElement(intermediate, section, parentElement, element, fileComponentId, fileId, null); break; default: this.ParseHelper.UnexpectedElement(parentElement, element); @@ -48,7 +48,35 @@ public override void ParseElement(Intermediate intermediate, IntermediateSection switch (element.Name.LocalName) { case "FirewallException": - this.ParseFirewallExceptionElement(intermediate, section, element, componentId, null); + this.ParseFirewallExceptionElement(intermediate, section, parentElement, element, componentId, null, null); + break; + default: + this.ParseHelper.UnexpectedElement(parentElement, element); + break; + } + break; + case "ServiceConfig": + var serviceConfigName = context["ServiceConfigServiceName"]; + var serviceConfigComponentId = context["ServiceConfigComponentId"]; + + switch (element.Name.LocalName) + { + case "FirewallException": + this.ParseFirewallExceptionElement(intermediate, section, parentElement, element, serviceConfigComponentId, null, serviceConfigName); + break; + default: + this.ParseHelper.UnexpectedElement(parentElement, element); + break; + } + break; + case "ServiceInstall": + var serviceInstallName = context["ServiceInstallName"]; + var serviceInstallComponentId = context["ServiceInstallComponentId"]; + + switch (element.Name.LocalName) + { + case "FirewallException": + this.ParseFirewallExceptionElement(intermediate, section, parentElement, element, serviceInstallComponentId, null, serviceInstallName); break; default: this.ParseHelper.UnexpectedElement(parentElement, element); @@ -64,24 +92,31 @@ public override void ParseElement(Intermediate intermediate, IntermediateSection /// /// Parses a FirewallException element. /// + /// The parent element of the one being parsed. /// The element to parse. /// Identifier of the component that owns this firewall exception. /// The file identifier of the parent element (null if nested under Component). - private void ParseFirewallExceptionElement(Intermediate intermediate, IntermediateSection section, XElement element, string componentId, string fileId) + /// The service name of the parent element (null if not nested under ServiceConfig or ServiceInstall). + private void ParseFirewallExceptionElement(Intermediate intermediate, IntermediateSection section, XElement parentElement, XElement element, string componentId, string fileId, string serviceName) { var sourceLineNumbers = this.ParseHelper.GetSourceLineNumbers(element); Identifier id = null; string name = null; - int attributes = 0; + int attributes = 0x4; // feaEnabled string file = null; string program = null; + string service = null; string port = null; int? protocol = null; + string protocolValue = null; int? profile = null; string scope = null; string remoteAddresses = null; string description = null; int? direction = null; + string interfaceTypes = null; + int? edgeTraversal = null; + string remotePort = null; foreach (var attrib in element.Attributes()) { @@ -98,7 +133,7 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia case "File": if (null != fileId) { - this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "File", "File")); + this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "File", parentElement.Name.LocalName)); } else { @@ -108,13 +143,13 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia case "IgnoreFailure": if (YesNoType.Yes == this.ParseHelper.GetAttributeYesNoValue(sourceLineNumbers, attrib)) { - attributes |= 0x1; // feaIgnoreFailures + attributes |= 0x1; // add feaIgnoreFailures } break; case "Program": if (null != fileId) { - this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "Program", "File")); + this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "Program", parentElement.Name.LocalName)); } else { @@ -125,7 +160,7 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia port = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); break; case "Protocol": - var protocolValue = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); + protocolValue = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); switch (protocolValue) { case "tcp": @@ -135,7 +170,12 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia protocol = FirewallConstants.NET_FW_IP_PROTOCOL_UDP; break; default: - this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Protocol", protocolValue, "tcp", "udp")); + int parsedProtocol; + if (!Int32.TryParse(protocolValue, out parsedProtocol) || parsedProtocol > 255 || parsedProtocol < 0) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Protocol", protocolValue, "tcp", "udp", "0-255")); + } + protocol = parsedProtocol; break; } break; @@ -149,8 +189,20 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia case "localSubnet": remoteAddresses = "LocalSubnet"; break; + case "DNS": + remoteAddresses = "dns"; + break; + case "DHCP": + remoteAddresses = "dhcp"; + break; + case "WINS": + remoteAddresses = "wins"; + break; + case "defaultGateway": + remoteAddresses = "DefaultGateway"; + break; default: - this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Scope", scope, "any", "localSubnet")); + this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "Scope", scope, "any", "localSubnet", "DNS", "DHCP", "WINS", "defaultGateway")); break; } break; @@ -183,6 +235,64 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia ? FirewallConstants.NET_FW_RULE_DIR_OUT : FirewallConstants.NET_FW_RULE_DIR_IN; break; + case "OverwriteOnChange": + if (YesNoType.Yes == this.ParseHelper.GetAttributeYesNoValue(sourceLineNumbers, attrib)) + { + attributes |= 0x2; // add feaOverwriteOnChange + } + else + { + attributes &= ~0x2; // remove feaOverwriteOnChange + } + break; + case "Enabled": + attributes |= 0x8; // add feaEnabledCustomized + if (YesNoType.Yes == this.ParseHelper.GetAttributeYesNoValue(sourceLineNumbers, attrib)) + { + attributes |= 0x4; // add feaEnabled + } + else + { + attributes &= ~0x4; // remove feaEnabled + } + break; + case "Service": + if (null != serviceName) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "Service", parentElement.Name.LocalName)); + } + else + { + service = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); + } + break; + case "InterfaceTypes": + this.ParseInterfaceTypesElement(element, attrib, ref interfaceTypes); + break; + case "EdgeTraversal": + var edgeTraversalValue = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); + switch (edgeTraversalValue) + { + case "Deny": + edgeTraversal = FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DENY; + break; + case "Allow": + edgeTraversal = FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_ALLOW; + break; + case "DeferToApp": + edgeTraversal = FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_APP; + break; + case "DeferToUser": + edgeTraversal = FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER; + break; + default: + this.Messaging.Write(ErrorMessages.IllegalAttributeValue(sourceLineNumbers, element.Name.LocalName, "EdgeTraversal", edgeTraversalValue, "Deny", "Allow", "DeferToApp", "DeferToUser")); + break; + } + break; + case "RemotePort": + remotePort = this.ParseHelper.GetAttributeValue(sourceLineNumbers, attrib); + break; default: this.ParseHelper.UnexpectedAttribute(element, attrib); break; @@ -227,14 +337,19 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia id = this.ParseHelper.CreateIdentifier("fex", name, remoteAddresses, componentId); } + if (null == service) + { + service = serviceName; + } + // Name is required if (null == name) { this.Messaging.Write(ErrorMessages.ExpectedAttribute(sourceLineNumbers, element.Name.LocalName, "Name")); } - // Scope or child RemoteAddress(es) are required - if (null == remoteAddresses) + // Scope or child RemoteAddress(es) are required, unless EdgeTraversal is set to DeferToUser + if (null == remoteAddresses && FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER != edgeTraversal) { this.Messaging.Write(ErrorMessages.ExpectedAttributeOrElement(sourceLineNumbers, element.Name.LocalName, "Scope", "RemoteAddress")); } @@ -251,6 +366,57 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia this.Messaging.Write(FirewallErrors.NoExceptionSpecified(sourceLineNumbers)); } + // Defer to user edge traversal setting can only be used in a firewall rule where program path and TCP/UDP protocol are specified with no additional conditions. + if (FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER == edgeTraversal) + { + if (protocol.HasValue && !(protocol == FirewallConstants.NET_FW_IP_PROTOCOL_TCP || protocol == FirewallConstants.NET_FW_IP_PROTOCOL_UDP)) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeValueWithLegalList(sourceLineNumbers, element.Name.LocalName, "Protocol", protocolValue, "tcp,udp")); + } + + if (String.IsNullOrEmpty(fileId) && String.IsNullOrEmpty(file) && String.IsNullOrEmpty(program)) + { + this.Messaging.Write(ErrorMessages.ExpectedAttribute(sourceLineNumbers, element.Name.LocalName, "Program", "EdgeTraversal", "DeferToUser")); + } + + if (null != port) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "Port", "EdgeTraversal", "DeferToUser")); + } + + if (null != scope) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "Scope", "EdgeTraversal", "DeferToUser")); + } + + if (null != profile) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "Profile", "EdgeTraversal", "DeferToUser")); + } + + if (null != service) + { + if (null != serviceName) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeValueWhenNested(sourceLineNumbers, element.Name.LocalName, "EdgeTraversal", "DeferToUser", parentElement.Name.LocalName)); + } + else + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "Service", "EdgeTraversal", "DeferToUser")); + } + } + + if (null != interfaceTypes) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "InterfaceTypes", "EdgeTraversal", "DeferToUser")); + } + + if (null != remotePort) + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWithOtherAttribute(sourceLineNumbers, element.Name.LocalName, "RemotePort", "EdgeTraversal", "DeferToUser")); + } + } + if (!this.Messaging.EncounteredError) { // at this point, File attribute and File parent element are treated the same @@ -263,12 +429,19 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia { Name = name, RemoteAddresses = remoteAddresses, - Profile = profile ?? FirewallConstants.NET_FW_PROFILE2_ALL, ComponentRef = componentId, Description = description, Direction = direction ?? FirewallConstants.NET_FW_RULE_DIR_IN, + Service = service, + InterfaceTypes = interfaceTypes, + RemotePort = remotePort, }); + if (FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER != edgeTraversal) + { + symbol.Profile = profile ?? FirewallConstants.NET_FW_PROFILE2_ALL; + } + if (!String.IsNullOrEmpty(port)) { symbol.Port = port; @@ -300,11 +473,63 @@ private void ParseFirewallExceptionElement(Intermediate intermediate, Intermedia symbol.Attributes = attributes; } + if (edgeTraversal.HasValue) + { + symbol.EdgeTraversal = edgeTraversal.Value; + } + this.ParseHelper.CreateCustomActionReference(sourceLineNumbers, section, "Wix4SchedFirewallExceptionsInstall", this.Context.Platform, CustomActionPlatforms.ARM64 | CustomActionPlatforms.X64 | CustomActionPlatforms.X86); this.ParseHelper.CreateCustomActionReference(sourceLineNumbers, section, "Wix4SchedFirewallExceptionsUninstall", this.Context.Platform, CustomActionPlatforms.ARM64 | CustomActionPlatforms.X64 | CustomActionPlatforms.X86); } } + /// + /// Parses an InterfaceTypes element + /// + /// The element to parse. + /// The attribute to parse. + private void ParseInterfaceTypesElement(XElement element, XAttribute attribute, ref string interfaceTypes) + { + var sourceLineNumbers = this.ParseHelper.GetSourceLineNumbers(element); + var interfaceTypeValue = this.ParseHelper.GetAttributeIntegerValue(sourceLineNumbers, attribute, 0, Int32.MaxValue); + + if (Int32.MaxValue == interfaceTypeValue) + { + interfaceTypes = "All"; + } + else + { + if (0x1 == (interfaceTypeValue & 0x1)) + { + interfaceTypes = "Wireless"; + } + + if (0x2 == (interfaceTypeValue & 0x2)) + { + if (String.IsNullOrEmpty(interfaceTypes)) + { + interfaceTypes = "Lan"; + } + else + { + interfaceTypes = String.Concat(interfaceTypes, ",", "Lan"); + } + } + + if (0x4 == (interfaceTypeValue & 0x4)) + { + if (String.IsNullOrEmpty(interfaceTypes)) + { + interfaceTypes = "RemoteAccess"; + } + else + { + interfaceTypes = String.Concat(interfaceTypes, ",", "RemoteAccess"); + } + } + } + } + /// /// Parses a RemoteAddress element /// diff --git a/src/ext/Firewall/wixext/FirewallConstants.cs b/src/ext/Firewall/wixext/FirewallConstants.cs index 7bb12ba47..0526a304d 100644 --- a/src/ext/Firewall/wixext/FirewallConstants.cs +++ b/src/ext/Firewall/wixext/FirewallConstants.cs @@ -2,12 +2,14 @@ namespace WixToolset.Firewall { - using System; - using System.Collections.Generic; - using System.Text; + using System.Xml.Linq; static class FirewallConstants { + internal static readonly XNamespace Namespace = "http://wixtoolset.org/schemas/v4/wxs/firewall"; + internal static readonly XName FirewallExceptionName = Namespace + "FirewallException"; + internal static readonly XName RemoteAddressName = Namespace + "RemoteAddress"; + // from icftypes.h public const int NET_FW_RULE_DIR_IN = 1; public const int NET_FW_RULE_DIR_OUT = 2; @@ -19,5 +21,11 @@ static class FirewallConstants public const int NET_FW_PROFILE2_PRIVATE = 0x0002; public const int NET_FW_PROFILE2_PUBLIC = 0x0004; public const int NET_FW_PROFILE2_ALL = 0x7FFFFFFF; + + // from icftypes.h + public const int NET_FW_EDGE_TRAVERSAL_TYPE_DENY = 0; + public const int NET_FW_EDGE_TRAVERSAL_TYPE_ALLOW = 1; + public const int NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_APP = 2; + public const int NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER = 3; } } diff --git a/src/ext/Firewall/wixext/FirewallDecompiler.cs b/src/ext/Firewall/wixext/FirewallDecompiler.cs index c9478de15..e5ed63ae9 100644 --- a/src/ext/Firewall/wixext/FirewallDecompiler.cs +++ b/src/ext/Firewall/wixext/FirewallDecompiler.cs @@ -2,54 +2,53 @@ namespace WixToolset.Firewall { -#if TODO_CONSIDER_DECOMPILER using System; - using System.Collections; - using System.Diagnostics; - using System.Globalization; + using System.Collections.Generic; + using System.Xml.Linq; using WixToolset.Data; + using WixToolset.Data.WindowsInstaller; using WixToolset.Extensibility; - using Firewall = WixToolset.Extensions.Serialize.Firewall; - using Wix = WixToolset.Data.Serialize; /// /// The decompiler for the WiX Toolset Firewall Extension. /// - public sealed class FirewallDecompiler : DecompilerExtension + public sealed class FirewallDecompiler : BaseWindowsInstallerDecompilerExtension { - /// - /// Creates a decompiler for Firewall Extension. - /// - public FirewallDecompiler() - { - this.TableDefinitions = FirewallExtensionData.GetExtensionTableDefinitions(); - } + public override IReadOnlyCollection TableDefinitions => FirewallTableDefinitions.All; /// - /// Get the extensions library to be removed. + /// Called at the beginning of the decompilation of a database. /// - /// Table definitions for library. - /// Library to remove from decompiled output. - public override Library GetLibraryToRemove(TableDefinitionCollection tableDefinitions) + /// The collection of all tables. + public override void PreDecompileTables(TableIndexedCollection tables) { - return FirewallExtensionData.GetExtensionLibrary(tableDefinitions); } /// /// Decompiles an extension table. /// /// The table to decompile. - public override void DecompileTable(Table table) + public override bool TryDecompileTable(Table table) { switch (table.Name) { - case "WixFirewallException": + case "Wix4FirewallException": this.DecompileWixFirewallExceptionTable(table); break; default: - base.DecompileTable(table); - break; + return false; } + + return true; + } + + /// + /// Finalize decompilation. + /// + /// The collection of all tables. + public override void PostDecompileTables(TableIndexedCollection tables) + { + this.FinalizeFirewallExceptionTable(tables); } /// @@ -60,38 +59,42 @@ private void DecompileWixFirewallExceptionTable(Table table) { foreach (Row row in table.Rows) { - Firewall.FirewallException fire = new Firewall.FirewallException(); - fire.Id = (string)row[0]; - fire.Name = (string)row[1]; + var firewallException = new XElement(FirewallConstants.FirewallExceptionName, + new XAttribute("Id", row.FieldAsString(0)), + new XAttribute("Name", row.FieldAsString(1)) + ); - string[] addresses = ((string)row[2]).Split(','); - if (1 == addresses.Length) + if (!row.IsColumnEmpty(2)) { - // special-case the Scope attribute values - if ("*" == addresses[0]) + string[] addresses = ((string)row[2]).Split(','); + if (1 == addresses.Length) { - fire.Scope = Firewall.FirewallException.ScopeType.any; - } - else if ("LocalSubnet" == addresses[0]) - { - fire.Scope = Firewall.FirewallException.ScopeType.localSubnet; + // special-case the Scope attribute values + if ("*" == addresses[0]) + { + firewallException.Add(new XAttribute("Scope", "any")); + } + else if ("LocalSubnet" == addresses[0]) + { + firewallException.Add(new XAttribute("Scope", "localSubnet")); + } + else + { + FirewallDecompiler.AddRemoteAddress(firewallException, addresses[0]); + } } else { - FirewallDecompiler.AddRemoteAddress(fire, addresses[0]); - } - } - else - { - foreach (string address in addresses) - { - FirewallDecompiler.AddRemoteAddress(fire, address); + foreach (string address in addresses) + { + FirewallDecompiler.AddRemoteAddress(firewallException, address); + } } } if (!row.IsColumnEmpty(3)) { - fire.Port = (string)row[3]; + firewallException.Add(new XAttribute("Port", row.FieldAsString(3))); } if (!row.IsColumnEmpty(4)) @@ -99,25 +102,28 @@ private void DecompileWixFirewallExceptionTable(Table table) switch (Convert.ToInt32(row[4])) { case FirewallConstants.NET_FW_IP_PROTOCOL_TCP: - fire.Protocol = Firewall.FirewallException.ProtocolType.tcp; + firewallException.Add(new XAttribute("Protocol", "tcp")); break; case FirewallConstants.NET_FW_IP_PROTOCOL_UDP: - fire.Protocol = Firewall.FirewallException.ProtocolType.udp; + firewallException.Add(new XAttribute("Protocol", "udp")); break; } } if (!row.IsColumnEmpty(5)) { - fire.Program = (string)row[5]; + firewallException.Add(new XAttribute("Program", row.FieldAsString(5))); } if (!row.IsColumnEmpty(6)) { - int attr = Convert.ToInt32(row[6]); - if (0x1 == (attr & 0x1)) // feaIgnoreFailures + var attr = Convert.ToInt32(row[6]); + AttributeIfNotNull("IgnoreFailure", 0x1 == (attr & 0x1)); + + // default value is true + if (0x2 != (attr & 0x2)) { - fire.IgnoreFailure = Firewall.YesNoType.yes; + AttributeIfNotNull("EdgeTraversal", false); } } @@ -126,24 +132,23 @@ private void DecompileWixFirewallExceptionTable(Table table) switch (Convert.ToInt32(row[7])) { case FirewallConstants.NET_FW_PROFILE2_DOMAIN: - fire.Profile = Firewall.FirewallException.ProfileType.domain; + firewallException.Add(new XAttribute("Profile", "domain")); break; case FirewallConstants.NET_FW_PROFILE2_PRIVATE: - fire.Profile = Firewall.FirewallException.ProfileType.@private; + firewallException.Add(new XAttribute("Profile", "private")); break; case FirewallConstants.NET_FW_PROFILE2_PUBLIC: - fire.Profile = Firewall.FirewallException.ProfileType.@public; + firewallException.Add(new XAttribute("Profile", "public")); break; case FirewallConstants.NET_FW_PROFILE2_ALL: - fire.Profile = Firewall.FirewallException.ProfileType.all; + firewallException.Add(new XAttribute("Profile", "all")); break; } } - // Description column is new in v3.6 - if (9 < row.Fields.Length && !row.IsColumnEmpty(9)) + if (!row.IsColumnEmpty(9)) { - fire.Description = (string)row[9]; + firewallException.Add(new XAttribute("Description", row.FieldAsString(9))); } if (!row.IsColumnEmpty(10)) @@ -151,32 +156,118 @@ private void DecompileWixFirewallExceptionTable(Table table) switch (Convert.ToInt32(row[10])) { case FirewallConstants.NET_FW_RULE_DIR_IN: - fire.Direction = Firewall.FirewallException.DirectionType.@in; + + firewallException.Add(AttributeIfNotNull("Outbound", false)); break; case FirewallConstants.NET_FW_RULE_DIR_OUT: - fire.Direction = Firewall.FirewallException.DirectionType.@out; + firewallException.Add(AttributeIfNotNull("Outbound", true)); break; } } - Wix.Component component = (Wix.Component)this.Core.GetIndexedElement("Component", (string)row[8]); - if (null != component) - { - component.AddChild(fire); - } - else + // Introduced after 4.0.1 + if (row.Fields.Length > 11) { - this.Core.OnMessage(WixWarnings.ExpectedForeignRow(row.SourceLineNumbers, table.Name, row.GetPrimaryKey(DecompilerConstants.PrimaryKeyDelimiter), "Component_", (string)row[6], "Component")); + if (!row.IsColumnEmpty(11)) + { + firewallException.Add(new XAttribute("Service", row.FieldAsString(11))); + } + + if (!row.IsColumnEmpty(12)) + { + var interfaceTypes = row.FieldAsString(12); + var interfaceTypesValue = 0; + if ("All" == interfaceTypes) + { + interfaceTypesValue = Int32.MaxValue; + } + else + { + if (interfaceTypes.Contains("Wireless")) + { + interfaceTypesValue |= 0x1; + } + + if (interfaceTypes.Contains("Lan")) + { + interfaceTypesValue |= 0x2; + } + + if (interfaceTypes.Contains("RemoteAccess")) + { + interfaceTypesValue |= 0x4; + } + } + + firewallException.Add(new XAttribute("InterfaceTypes", interfaceTypesValue)); + } + + if (!row.IsColumnEmpty(13)) + { + switch (Convert.ToInt32(row[13])) + { + case FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DENY: + firewallException.Add(new XAttribute("EdgeTraversal", "Deny")); + break; + case FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_ALLOW: + firewallException.Add(new XAttribute("EdgeTraversal", "Allow")); + break; + case FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_APP: + firewallException.Add(new XAttribute("EdgeTraversal", "DeferToApp")); + break; + case FirewallConstants.NET_FW_EDGE_TRAVERSAL_TYPE_DEFER_TO_USER: + firewallException.Add(new XAttribute("EdgeTraversal", "DeferToUser")); + break; + } + } + + if (!row.IsColumnEmpty(14)) + { + firewallException.Add(new XAttribute("RemotePort", row.FieldAsString(14))); + } } + + this.DecompilerHelper.IndexElement(row, firewallException); } } - private static void AddRemoteAddress(Firewall.FirewallException fire, string address) + private static void AddRemoteAddress(XElement firewallException, string address) { - Firewall.RemoteAddress remote = new Firewall.RemoteAddress(); - remote.Content = address; - fire.AddChild(remote); + var remoteAddress = new XElement(FirewallConstants.RemoteAddressName, + new XAttribute("Value", address) + ); + + firewallException.AddAfterSelf(remoteAddress); + } + + private static XAttribute AttributeIfNotNull(string name, bool value) + { + return new XAttribute(name, value ? "yes" : "no"); + } + + /// + /// Finalize the FirewallException table. + /// + /// Collection of all tables. + private void FinalizeFirewallExceptionTable(TableIndexedCollection tables) + { + if (tables.TryGetTable("Wix4FirewallException", out var firewallExceptionTable)) + { + foreach (var row in firewallExceptionTable.Rows) + { + var xmlConfig = this.DecompilerHelper.GetIndexedElement(row); + + var componentId = row.FieldAsString(8); + if (this.DecompilerHelper.TryGetIndexedElement("Component", componentId, out var component)) + { + component.Add(xmlConfig); + } + else + { + this.Messaging.Write(WarningMessages.ExpectedForeignRow(row.SourceLineNumbers, firewallExceptionTable.Name, row.GetPrimaryKey(), "Component_", componentId, "Component")); + } + } + } } } -#endif } diff --git a/src/ext/Firewall/wixext/FirewallExtensionFactory.cs b/src/ext/Firewall/wixext/FirewallExtensionFactory.cs index 279b322a2..20209c28e 100644 --- a/src/ext/Firewall/wixext/FirewallExtensionFactory.cs +++ b/src/ext/Firewall/wixext/FirewallExtensionFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information. +// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information. namespace WixToolset.Firewall { @@ -11,6 +11,7 @@ public class FirewallExtensionFactory : BaseExtensionFactory protected override IReadOnlyCollection ExtensionTypes => new[] { typeof(FirewallCompiler), + typeof(FirewallDecompiler), typeof(FirewallExtensionData), typeof(FirewallWindowsInstallerBackendBinderExtension), }; diff --git a/src/ext/Firewall/wixext/FirewallTableDefinitions.cs b/src/ext/Firewall/wixext/FirewallTableDefinitions.cs index 04918f5f0..6bb6c63d5 100644 --- a/src/ext/Firewall/wixext/FirewallTableDefinitions.cs +++ b/src/ext/Firewall/wixext/FirewallTableDefinitions.cs @@ -13,15 +13,19 @@ public static class FirewallTableDefinitions { new ColumnDefinition("Wix4FirewallException", ColumnType.String, 72, primaryKey: true, nullable: false, ColumnCategory.Identifier, description: "The primary key, a non-localized token.", modularizeType: ColumnModularizeType.Column), new ColumnDefinition("Name", ColumnType.Localized, 255, primaryKey: false, nullable: true, ColumnCategory.Formatted, description: "Localizable display name.", modularizeType: ColumnModularizeType.Property), - new ColumnDefinition("RemoteAddresses", ColumnType.String, 0, primaryKey: false, nullable: false, ColumnCategory.Formatted, description: "Remote address to accept incoming connections from.", modularizeType: ColumnModularizeType.Property), + new ColumnDefinition("RemoteAddresses", ColumnType.String, 0, primaryKey: false, nullable: true, ColumnCategory.Formatted, description: "Remote address to accept incoming connections from.", modularizeType: ColumnModularizeType.Property), new ColumnDefinition("Port", ColumnType.String, 0, primaryKey: false, nullable: true, ColumnCategory.Formatted, minValue: 1, description: "Port number.", modularizeType: ColumnModularizeType.Property), - new ColumnDefinition("Protocol", ColumnType.Number, 1, primaryKey: false, nullable: true, ColumnCategory.Integer, minValue: 6, maxValue: 17, description: "Protocol (6=TCP; 17=UDP)."), + new ColumnDefinition("Protocol", ColumnType.Number, 1, primaryKey: false, nullable: true, ColumnCategory.Integer, minValue: 0, maxValue: 255, description: "Protocol (6=TCP; 17=UDP). https://www.iana.org/assignments/protocol-numbers"), new ColumnDefinition("Program", ColumnType.String, 255, primaryKey: false, nullable: true, ColumnCategory.Formatted, description: "Exception for a program (formatted path name).", modularizeType: ColumnModularizeType.Property), - new ColumnDefinition("Attributes", ColumnType.Number, 4, primaryKey: false, nullable: true, ColumnCategory.Unknown, description: "Vital=1"), - new ColumnDefinition("Profile", ColumnType.Number, 4, primaryKey: false, nullable: false, ColumnCategory.Integer, minValue: 1, maxValue: 2147483647, description: "Profile (1=domain; 2=private; 4=public; 2147483647=all)."), + new ColumnDefinition("Attributes", ColumnType.Number, 4, primaryKey: false, nullable: true, ColumnCategory.Unknown, description: "Vital=1, OverwriteOnChange=2, Enabled=4, EnabledCustomized=8"), + new ColumnDefinition("Profile", ColumnType.Number, 4, primaryKey: false, nullable: true, ColumnCategory.Integer, minValue: 1, description: "Profile (1=domain; 2=private; 4=public; 2147483647=all)."), new ColumnDefinition("Component_", ColumnType.String, 72, primaryKey: false, nullable: false, ColumnCategory.Identifier, keyTable: "Component", keyColumn: 1, description: "Foreign key into the Component table referencing component that controls the firewall configuration.", modularizeType: ColumnModularizeType.Column), new ColumnDefinition("Description", ColumnType.String, 255, primaryKey: false, nullable: true, ColumnCategory.Formatted, description: "Description displayed in Windows Firewall manager for this firewall rule."), new ColumnDefinition("Direction", ColumnType.Number, 1, primaryKey: false, nullable: true, ColumnCategory.Integer, minValue: 1, maxValue: 2, description: "Direction (1=in; 2=out)"), + new ColumnDefinition("Service", ColumnType.String, 256, primaryKey: false, nullable: true, ColumnCategory.Formatted, description: "Windows Service short name (optional)."), + new ColumnDefinition("InterfaceTypes", ColumnType.String, 0, primaryKey: false, nullable: true, ColumnCategory.Unknown, description: "Optional comma separated list of interface types enforced by this firewall rule (combination of Wireless,Lan,RemoteAccess or All). Null is equivalent to All."), + new ColumnDefinition("EdgeTraversal", ColumnType.Number, 1, primaryKey: false, nullable: true, ColumnCategory.Integer, minValue: 0, maxValue: 3, description: "Edge traversal (0=Deny; 1=Allow; 2=DeferToApp; 3=DeferToUser;)"), + new ColumnDefinition("RemotePort", ColumnType.String, 0, primaryKey: false, nullable: true, ColumnCategory.Formatted, minValue: 1, description: "Remote port number.", modularizeType: ColumnModularizeType.Property), }, symbolIdIsPrimaryKey: true ); diff --git a/src/ext/Firewall/wixext/Symbols/WixFirewallExceptionSymbol.cs b/src/ext/Firewall/wixext/Symbols/WixFirewallExceptionSymbol.cs index 620de9693..4d77e32dd 100644 --- a/src/ext/Firewall/wixext/Symbols/WixFirewallExceptionSymbol.cs +++ b/src/ext/Firewall/wixext/Symbols/WixFirewallExceptionSymbol.cs @@ -21,6 +21,10 @@ public static partial class FirewallSymbolDefinitions new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.ComponentRef), IntermediateFieldType.String), new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.Description), IntermediateFieldType.String), new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.Direction), IntermediateFieldType.Number), + new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.Service), IntermediateFieldType.String), + new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.InterfaceTypes), IntermediateFieldType.String), + new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.EdgeTraversal), IntermediateFieldType.Number), + new IntermediateFieldDefinition(nameof(WixFirewallExceptionSymbolFields.RemotePort), IntermediateFieldType.String), }, typeof(WixFirewallExceptionSymbol)); } @@ -42,6 +46,10 @@ public enum WixFirewallExceptionSymbolFields ComponentRef, Description, Direction, + Service, + InterfaceTypes, + EdgeTraversal, + RemotePort } public class WixFirewallExceptionSymbol : IntermediateSymbol @@ -115,5 +123,29 @@ public int Direction get => this.Fields[(int)WixFirewallExceptionSymbolFields.Direction].AsNumber(); set => this.Set((int)WixFirewallExceptionSymbolFields.Direction, value); } + + public string Service + { + get => this.Fields[(int)WixFirewallExceptionSymbolFields.Service].AsString(); + set => this.Set((int)WixFirewallExceptionSymbolFields.Service, value); + } + + public string InterfaceTypes + { + get => this.Fields[(int)WixFirewallExceptionSymbolFields.InterfaceTypes].AsString(); + set => this.Set((int)WixFirewallExceptionSymbolFields.InterfaceTypes, value); + } + + public int EdgeTraversal + { + get => this.Fields[(int)WixFirewallExceptionSymbolFields.EdgeTraversal].AsNumber(); + set => this.Set((int)WixFirewallExceptionSymbolFields.EdgeTraversal, value); + } + + public string RemotePort + { + get => this.Fields[(int)WixFirewallExceptionSymbolFields.RemotePort].AsString(); + set => this.Set((int)WixFirewallExceptionSymbolFields.RemotePort, value); + } } -} \ No newline at end of file +} diff --git a/src/ext/Iis/ca/scacertexec.cpp b/src/ext/Iis/ca/scacertexec.cpp index 95870c79b..352644afe 100644 --- a/src/ext/Iis/ca/scacertexec.cpp +++ b/src/ext/Iis/ca/scacertexec.cpp @@ -154,7 +154,7 @@ static HRESULT ExecuteCertificateOperation( LPWSTR pwzPFXPassword = NULL; LPWSTR pwzFilePath = NULL; BYTE* pbData = NULL; - DWORD cbData = 0; + DWORD_PTR cbData = 0; DWORD_PTR cbPFXPassword = 0; BOOL fUserStoreLocation = (CERT_SYSTEM_STORE_CURRENT_USER == dwStoreLocation); @@ -174,7 +174,7 @@ static HRESULT ExecuteCertificateOperation( ExitOnFailure(hr, "Failed to parse certificate attribute"); if (SCA_ACTION_INSTALL == saAction) // install operations need more data { - hr = WcaReadStreamFromCaData(&pwz, &pbData, (DWORD_PTR*)&cbData); + hr = WcaReadStreamFromCaData(&pwz, &pbData, &cbData); ExitOnFailure(hr, "Failed to parse certificate stream."); hr = WcaReadStringFromCaData(&pwz, &pwzPFXPassword); @@ -192,7 +192,7 @@ static HRESULT ExecuteCertificateOperation( // CertAddCertificateContextToStore(CERT_STORE_ADD_REPLACE_EXISTING) does not remove the private key if the cert is replaced UninstallCertificatePackage(hCertStore, fUserStoreLocation, pwzName); - hr = InstallCertificatePackage(hCertStore, fUserStoreLocation, pwzName, pbData, cbData, iAttributes & SCA_CERT_ATTRIBUTE_VITAL, pwzPFXPassword); + hr = InstallCertificatePackage(hCertStore, fUserStoreLocation, pwzName, pbData, (DWORD)cbData, iAttributes & SCA_CERT_ATTRIBUTE_VITAL, pwzPFXPassword); ExitOnFailure(hr, "Failed to install certificate."); } else diff --git a/src/ext/Util/wixext/UtilCompiler.cs b/src/ext/Util/wixext/UtilCompiler.cs index 5fefed916..f7bb0614d 100644 --- a/src/ext/Util/wixext/UtilCompiler.cs +++ b/src/ext/Util/wixext/UtilCompiler.cs @@ -3123,8 +3123,14 @@ private void ParseServiceConfigElement(Intermediate intermediate, IntermediateSe // if this element is a child of ServiceInstall then ignore the service name provided. if ("ServiceInstall" == parentTableName) { - // TODO: the ServiceName attribute should not be allowed in this case (the overwriting behavior may confuse users) - serviceName = parentTableServiceName; + if (null == serviceName || parentTableServiceName == serviceName) + { + serviceName = parentTableServiceName; + } + else + { + this.Messaging.Write(ErrorMessages.IllegalAttributeWhenNested(sourceLineNumbers, element.Name.LocalName, "ServiceName", parentTableName)); + } newService = true; } else @@ -3136,7 +3142,8 @@ private void ParseServiceConfigElement(Intermediate intermediate, IntermediateSe } } - this.ParseHelper.ParseForExtensionElements(this.Context.Extensions, intermediate, section, element); + var context = new Dictionary() { { "ServiceConfigComponentId", componentId }, { "ServiceConfigServiceName", serviceName } }; + this.ParseHelper.ParseForExtensionElements(this.Context.Extensions, intermediate, section, element, context); if (!this.Messaging.EncounteredError) {