Skip to content

Commit

Permalink
Add target based scaling support for MSSQL (#169)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Gillum <[email protected]>
  • Loading branch information
bachuv and cgillum authored Oct 10, 2024
1 parent 6003e41 commit e477b00
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 19 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# Changelog

## v1.3.1 (Unreleased)
## v1.4.0

### New

* Support for Azure Functions target-based scaling ([#169](https://github.com/microsoft/durabletask-mssql/pull/169))
* Added `net6.0` TFM to Microsoft.DurableTask.SqlServer.AzureFunctions

### Updates

* Fix SQL retry logic to open a new connection if a previous failure closed the connection ([#221](https://github.com/microsoft/durabletask-mssql/pull/221)) - contributed by [@microrama](https://github.com/microrama)
* Pin Microsoft.Azure.WebJobs.Extensions.DurableTask dependency to 2.13.7 instead of wildcard to avoid accidental build breaks

## v1.3.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
<Import Project="../common.props" />

<PropertyGroup>
<TargetFrameworks>netstandard2.0</TargetFrameworks>
<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' == 'net6.0'">
<DefineConstants>$(DefineConstants);FUNCTIONS_V4</DefineConstants>
</PropertyGroup>

<!-- NuGet package settings -->
Expand All @@ -16,7 +20,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Azure.WebJobs.Extensions.DurableTask" Version="2.*" />
<PackageReference Include="Microsoft.Azure.WebJobs.Extensions.DurableTask" Version="2.13.7" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class SqlDurabilityProvider : DurabilityProvider
readonly SqlOrchestrationService service;

SqlScaleMonitor? scaleMonitor;
#if FUNCTIONS_V4
SqlTargetScaler? targetScaler;
#endif

public SqlDurabilityProvider(
SqlOrchestrationService service,
Expand Down Expand Up @@ -197,8 +200,33 @@ public override bool TryGetScaleMonitor(
string storageConnectionString,
out IScaleMonitor scaleMonitor)
{
scaleMonitor = this.scaleMonitor ??= new SqlScaleMonitor(this.service, hubName);
if (this.scaleMonitor == null)
{
var sqlMetricsProvider = new SqlMetricsProvider(this.service);
this.scaleMonitor = new SqlScaleMonitor(hubName, sqlMetricsProvider);
}

scaleMonitor = this.scaleMonitor;
return true;
}

#if FUNCTIONS_V4
public override bool TryGetTargetScaler(
string functionId,
string functionName,
string hubName,
string connectionName,
out ITargetScaler targetScaler)
{
if (this.targetScaler == null)
{
var sqlMetricsProvider = new SqlMetricsProvider(this.service);
this.targetScaler = new SqlTargetScaler(hubName, sqlMetricsProvider);
}

targetScaler = this.targetScaler;
return true;
}
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public DurabilityProvider GetDurabilityProvider(DurableClientAttribute attribute
lock (this.clientProviders)
{
string key = GetDurabilityProviderKey(attribute);
if (this.clientProviders.TryGetValue(key, out DurabilityProvider clientProvider))
if (this.clientProviders.TryGetValue(key, out DurabilityProvider? clientProvider))
{
return clientProvider;
}
Expand Down
29 changes: 29 additions & 0 deletions src/DurableTask.SqlServer.AzureFunctions/SqlMetricsProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace DurableTask.SqlServer.AzureFunctions
{
using System.Threading;
using System.Threading.Tasks;

public class SqlMetricsProvider
{
readonly SqlOrchestrationService service;

public SqlMetricsProvider(SqlOrchestrationService service)
{
this.service = service;
}

public virtual async Task<SqlScaleMetric> GetMetricsAsync(int? previousWorkerCount = null)
{
// GetRecommendedReplicaCountAsync will write a trace if the recommendation results
// in a worker count that is different from the worker count we pass in as an argument.
int recommendedReplicaCount = await this.service.GetRecommendedReplicaCountAsync(
previousWorkerCount,
CancellationToken.None);

return new SqlScaleMetric { RecommendedReplicaCount = recommendedReplicaCount };
}
}
}
2 changes: 1 addition & 1 deletion src/DurableTask.SqlServer.AzureFunctions/SqlScaleMetric.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace DurableTask.SqlServer.AzureFunctions
{
using Microsoft.Azure.WebJobs.Host.Scale;

class SqlScaleMetric : ScaleMetrics
public class SqlScaleMetric : ScaleMetrics
{
public int RecommendedReplicaCount { get; set; }
}
Expand Down
24 changes: 13 additions & 11 deletions src/DurableTask.SqlServer.AzureFunctions/SqlScaleMonitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,22 @@ class SqlScaleMonitor : IScaleMonitor<SqlScaleMetric>
static readonly ScaleStatus NoScaleVote = new ScaleStatus { Vote = ScaleVote.None };
static readonly ScaleStatus ScaleOutVote = new ScaleStatus { Vote = ScaleVote.ScaleOut };

readonly SqlOrchestrationService service;
readonly SqlMetricsProvider metricsProvider;

int? previousWorkerCount = -1;

public SqlScaleMonitor(SqlOrchestrationService service, string taskHubName)
public SqlScaleMonitor(string taskHubName, SqlMetricsProvider sqlMetricsProvider)
{
this.service = service ?? throw new ArgumentNullException(nameof(service));
this.Descriptor = new ScaleMonitorDescriptor($"DurableTask-SqlServer:{taskHubName ?? "default"}");
// Scalers in Durable Functions are shared for all functions in the same task hub.
// So instead of using a function ID, we use the task hub name as the basis for the descriptor ID.
string id = $"DurableTask-SqlServer:{taskHubName ?? "default"}";

#if FUNCTIONS_V4
this.Descriptor = new ScaleMonitorDescriptor(id: id, functionId: id);
#else
this.Descriptor = new ScaleMonitorDescriptor(id);
#endif
this.metricsProvider = sqlMetricsProvider ?? throw new ArgumentNullException(nameof(sqlMetricsProvider));
}

/// <inheritdoc />
Expand All @@ -38,13 +46,7 @@ public SqlScaleMonitor(SqlOrchestrationService service, string taskHubName)
/// <inheritdoc />
public async Task<SqlScaleMetric> GetMetricsAsync()
{
// GetRecommendedReplicaCountAsync will write a trace if the recommendation results
// in a worker count that is different from the worker count we pass in as an argument.
int recommendedReplicaCount = await this.service.GetRecommendedReplicaCountAsync(
this.previousWorkerCount,
CancellationToken.None);

return new SqlScaleMetric { RecommendedReplicaCount = recommendedReplicaCount };
return await this.metricsProvider.GetMetricsAsync(this.previousWorkerCount);
}

/// <inheritdoc />
Expand Down
37 changes: 37 additions & 0 deletions src/DurableTask.SqlServer.AzureFunctions/SqlTargetScaler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#if FUNCTIONS_V4
namespace DurableTask.SqlServer.AzureFunctions
{
using System;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Host.Scale;

public class SqlTargetScaler : ITargetScaler
{
readonly SqlMetricsProvider sqlMetricsProvider;

public SqlTargetScaler(string taskHubName, SqlMetricsProvider sqlMetricsProvider)
{
this.sqlMetricsProvider = sqlMetricsProvider;

// Scalers in Durable Functions are shared for all functions in the same task hub.
// So instead of using a function ID, we use the task hub name as the basis for the descriptor ID.
string id = $"DurableTask-SqlServer:{taskHubName ?? "default"}";
this.TargetScalerDescriptor = new TargetScalerDescriptor(id);
}

public TargetScalerDescriptor TargetScalerDescriptor { get; }

public async Task<TargetScalerResult> GetScaleResultAsync(TargetScalerContext context)
{
SqlScaleMetric sqlScaleMetric = await this.sqlMetricsProvider.GetMetricsAsync();
return new TargetScalerResult
{
TargetWorkerCount = Math.Max(0, sqlScaleMetric.RecommendedReplicaCount),
};
}
}
}
#endif
2 changes: 1 addition & 1 deletion src/common.props
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<!-- Version settings: https://andrewlock.net/version-vs-versionsuffix-vs-packageversion-what-do-they-all-mean/ -->
<PropertyGroup>
<MajorVersion>1</MajorVersion>
<MinorVersion>3</MinorVersion>
<MinorVersion>4</MinorVersion>
<PatchVersion>0</PatchVersion>
<VersionPrefix>$(MajorVersion).$(MinorVersion).$(PatchVersion)</VersionPrefix>
<VersionSuffix></VersionSuffix>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace DurableTask.SqlServer.AzureFunctions.Tests
{
using DurableTask.Core;
using Microsoft.Azure.WebJobs.Extensions.DurableTask;
using Microsoft.Azure.WebJobs.Host.Scale;
using Moq;
using Xunit;

public class TargetBasedScalingTests
{
readonly Mock<SqlMetricsProvider> metricsProviderMock;
readonly Mock<IOrchestrationService> orchestrationServiceMock;

public TargetBasedScalingTests()
{
this.orchestrationServiceMock = new Mock<IOrchestrationService>(MockBehavior.Strict);

SqlOrchestrationService? nullServiceArg = null; // not needed for this test
this.metricsProviderMock = new Mock<SqlMetricsProvider>(
behavior: MockBehavior.Strict,
nullServiceArg);
}

[Theory]
[InlineData(0)]
[InlineData(10)]
[InlineData(20)]
public async void TargetBasedScalingTest(int expectedTargetWorkerCount)
{
var durabilityProviderMock = new Mock<DurabilityProvider>(
MockBehavior.Strict,
"storageProviderName",
this.orchestrationServiceMock.Object,
new Mock<IOrchestrationServiceClient>().Object,
"connectionName");

var sqlScaleMetric = new SqlScaleMetric()
{
RecommendedReplicaCount = expectedTargetWorkerCount,
};

this.metricsProviderMock.Setup(m => m.GetMetricsAsync(null)).ReturnsAsync(sqlScaleMetric);

var targetScaler = new SqlTargetScaler(
"functionId",
this.metricsProviderMock.Object);

TargetScalerResult result = await targetScaler.GetScaleResultAsync(new TargetScalerContext());

Assert.Equal(expectedTargetWorkerCount, result.TargetWorkerCount);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async Task ValidateDatabaseSchemaAsync(TestDatabase database, string schemaName
database.ConnectionString,
schemaName);
Assert.Equal(1, currentSchemaVersion.Major);
Assert.Equal(3, currentSchemaVersion.Minor);
Assert.Equal(4, currentSchemaVersion.Minor);
Assert.Equal(0, currentSchemaVersion.Patch);
}

Expand Down

0 comments on commit e477b00

Please sign in to comment.