Skip to content

Commit

Permalink
Add support for enumeration serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
AArnott committed Oct 28, 2024
1 parent 86d057a commit 12ffadd
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 15 deletions.
5 changes: 5 additions & 0 deletions src/Nerdbank.MessagePack/Converters/DictionaryConverter`3.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ internal class DictionaryConverter<TDictionary, TKey, TValue>(Func<TDictionary,
/// <inheritdoc/>
public override TDictionary? Deserialize(ref MessagePackReader reader)
{
if (reader.TryReadNil())
{
return default;
}

throw new NotSupportedException();
}

Expand Down
172 changes: 172 additions & 0 deletions src/Nerdbank.MessagePack/Converters/EnumerableConverter`2.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright (c) Andrew Arnott. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

#pragma warning disable SA1402 // File may only contain a single type

namespace Nerdbank.MessagePack.Converters;

/// <summary>
/// Serializes an enumerable.
/// Deserialization is not supported.
/// </summary>
/// <typeparam name="TEnumerable">The concrete type of enumerable.</typeparam>
/// <typeparam name="TElement">The type of element in the enumerable.</typeparam>
internal class EnumerableConverter<TEnumerable, TElement>(Func<TEnumerable, IEnumerable<TElement>> getEnumerable, IMessagePackConverter<TElement> elementConverter) : IMessagePackConverter<TEnumerable>
{
/// <inheritdoc/>
public override TEnumerable? Deserialize(ref MessagePackReader reader)
{
if (reader.TryReadNil())
{
return default;
}

throw new NotImplementedException();
}

/// <inheritdoc/>
public override void Serialize(ref MessagePackWriter writer, ref TEnumerable? value)
{
if (value is null)
{
writer.WriteNil();
return;
}

IEnumerable<TElement> enumerable = getEnumerable(value);
if (!Enumerable.TryGetNonEnumeratedCount(enumerable, out int count))
{
writer.WriteArrayHeader(count);
foreach (TElement element in enumerable)
{
TElement? el = element;
elementConverter.Serialize(ref writer, ref el);
}
}
else
{
TElement?[] array = enumerable.ToArray();
writer.WriteArrayHeader(array.Length);
for (int i = 0; i < array.Length; i++)
{
elementConverter.Serialize(ref writer, ref array[i]);
}
}
}

/// <summary>
/// Reads one element from the reader.
/// </summary>
/// <param name="reader">The reader.</param>
/// <returns>The element.</returns>
protected TElement ReadElement(ref MessagePackReader reader) => elementConverter.Deserialize(ref reader)!;
}

/// <summary>
/// Serializes and deserializes a mutable enumerable.
/// </summary>
/// <inheritdoc cref="EnumerableConverter{TEnumerable, TElement}"/>
/// <param name="getEnumerable"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='getEnumerable']"/></param>
/// <param name="elementConverter"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='elementConverter']"/></param>
/// <param name="addElement">The delegate that adds an element to the enumerable.</param>
/// <param name="ctor">The default constructor for the enumerable type.</param>
internal class MutableEnumerableConverter<TEnumerable, TElement>(
Func<TEnumerable, IEnumerable<TElement>> getEnumerable,
IMessagePackConverter<TElement> elementConverter,
Setter<TEnumerable, TElement> addElement,
Func<TEnumerable> ctor) : EnumerableConverter<TEnumerable, TElement>(getEnumerable, elementConverter)
{
/// <inheritdoc/>
public override TEnumerable? Deserialize(ref MessagePackReader reader)
{
if (reader.TryReadNil())
{
return default;
}

TEnumerable result = ctor();
int count = reader.ReadArrayHeader();
for (int i = 0; i < count; i++)
{
addElement(ref result, this.ReadElement(ref reader));
}

return result;
}
}

/// <summary>
/// Serializes and deserializes an immutable enumerable.
/// </summary>
/// <inheritdoc cref="EnumerableConverter{TEnumerable, TElement}"/>
/// <param name="getEnumerable"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='getEnumerable']"/></param>
/// <param name="elementConverter"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='elementConverter']"/></param>
/// <param name="ctor">A enumerable initializer that constructs from a span of elements.</param>
internal class SpanEnumerableConverter<TEnumerable, TElement>(
Func<TEnumerable, IEnumerable<TElement>> getEnumerable,
IMessagePackConverter<TElement> elementConverter,
SpanConstructor<TElement, TEnumerable> ctor) : EnumerableConverter<TEnumerable, TElement>(getEnumerable, elementConverter)
{
/// <inheritdoc/>
public override TEnumerable? Deserialize(ref MessagePackReader reader)
{
if (reader.TryReadNil())
{
return default;
}

int count = reader.ReadArrayHeader();
TElement[] elements = ArrayPool<TElement>.Shared.Rent(count);
try
{
for (int i = 0; i < count; i++)
{
elements[i] = this.ReadElement(ref reader);
}

return ctor(elements.AsSpan(0, count));
}
finally
{
ArrayPool<TElement>.Shared.Return(elements);
}
}
}

/// <summary>
/// Serializes and deserializes an enumerable that initializes from an enumerable of elements.
/// </summary>
/// <inheritdoc cref="EnumerableConverter{TEnumerable, TElement}"/>
/// <param name="getEnumerable"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='getEnumerable']"/></param>
/// <param name="elementConverter"><inheritdoc cref="EnumerableConverter{TEnumerable, TElement}" path="/param[@name='elementConverter']"/></param>
/// <param name="ctor">A enumerable initializer that constructs from an enumerable of elements.</param>
internal class EnumerableEnumerableConverter<TEnumerable, TElement>(
Func<TEnumerable, IEnumerable<TElement>> getEnumerable,
IMessagePackConverter<TElement> elementConverter,
Func<IEnumerable<TElement>, TEnumerable> ctor) : EnumerableConverter<TEnumerable, TElement>(getEnumerable, elementConverter)
{
/// <inheritdoc/>
public override TEnumerable? Deserialize(ref MessagePackReader reader)
{
if (reader.TryReadNil())
{
return default;
}

int count = reader.ReadArrayHeader();
TElement[] elements = ArrayPool<TElement>.Shared.Rent(count);
try
{
for (int i = 0; i < count; i++)
{
elements[i] = this.ReadElement(ref reader);
}

return ctor(elements.Take(count));
}
finally
{
ArrayPool<TElement>.Shared.Return(elements);
}
}
}
36 changes: 21 additions & 15 deletions src/Nerdbank.MessagePack/StandardVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,28 @@ internal class StandardVisitor(MessagePackSerializer owner) : TypeShapeVisitor
Func<TDictionary, IReadOnlyDictionary<TKey, TValue>> getReadable = dictionaryShape.GetGetDictionary();

// Deserialization functions.
switch (dictionaryShape.ConstructionStrategy)
return dictionaryShape.ConstructionStrategy switch
{
case CollectionConstructionStrategy.None:
return new DictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter);
case CollectionConstructionStrategy.Mutable:
Setter<TDictionary, KeyValuePair<TKey, TValue>> addEntry = dictionaryShape.GetAddKeyValuePair();
Func<TDictionary> ctor = dictionaryShape.GetDefaultConstructor();
return new MutableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, addEntry, ctor);
case CollectionConstructionStrategy.Span:
SpanConstructor<KeyValuePair<TKey, TValue>, TDictionary> spanCtor = dictionaryShape.GetSpanConstructor();
return new ImmutableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, spanCtor);
case CollectionConstructionStrategy.Enumerable:
Func<IEnumerable<KeyValuePair<TKey, TValue>>, TDictionary> enumCtor = dictionaryShape.GetEnumerableConstructor();
return new EnumerableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, enumCtor);
default: throw new NotSupportedException($"Unrecognized dictionary pattern: {typeof(TDictionary).Name}");
}
CollectionConstructionStrategy.None => new DictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter),
CollectionConstructionStrategy.Mutable => new MutableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, dictionaryShape.GetAddKeyValuePair(), dictionaryShape.GetDefaultConstructor()),
CollectionConstructionStrategy.Span => new ImmutableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, dictionaryShape.GetSpanConstructor()),
CollectionConstructionStrategy.Enumerable => new EnumerableDictionaryConverter<TDictionary, TKey, TValue>(getReadable, keyConverter, valueConverter, dictionaryShape.GetEnumerableConstructor()),
_ => throw new NotSupportedException($"Unrecognized dictionary pattern: {typeof(TDictionary).Name}"),
};
}

/// <inheritdoc/>
public override object? VisitEnumerable<TEnumerable, TElement>(IEnumerableTypeShape<TEnumerable, TElement> enumerableShape, object? state = null)
{
IMessagePackConverter<TElement> elementConverter = this.GetConverter(enumerableShape.ElementType);
return enumerableShape.ConstructionStrategy switch
{
CollectionConstructionStrategy.None => new EnumerableConverter<TEnumerable, TElement>(enumerableShape.GetGetEnumerable(), elementConverter),
CollectionConstructionStrategy.Mutable => new MutableEnumerableConverter<TEnumerable, TElement>(enumerableShape.GetGetEnumerable(), elementConverter, enumerableShape.GetAddElement(), enumerableShape.GetDefaultConstructor()),
CollectionConstructionStrategy.Span => new SpanEnumerableConverter<TEnumerable, TElement>(enumerableShape.GetGetEnumerable(), elementConverter, enumerableShape.GetSpanConstructor()),
CollectionConstructionStrategy.Enumerable => new EnumerableEnumerableConverter<TEnumerable, TElement>(enumerableShape.GetGetEnumerable(), elementConverter, enumerableShape.GetEnumerableConstructor()),
_ => throw new NotSupportedException($"Unrecognized enumerable pattern: {typeof(TEnumerable).Name}"),
};
}

/// <summary>
Expand Down
27 changes: 27 additions & 0 deletions test/Nerdbank.MessagePack.Tests/ByValueEquality.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,31 @@ internal static bool Equal<TKey, TValue>(IReadOnlyDictionary<TKey, TValue>? left

return true;
}

internal static bool Equal<T>(IEnumerable<T>? left, IEnumerable<T>? right, IEqualityComparer<T>? equalityComparer = null) => Equal(left?.ToArray(), right?.ToArray(), equalityComparer);

internal static bool Equal<T>(IReadOnlyList<T>? left, IReadOnlyList<T>? right, IEqualityComparer<T>? equalityComparer = null)
{
equalityComparer ??= EqualityComparer<T>.Default;

if (left is null || right is null)
{
return left is null == right is null;
}

if (left.Count != right.Count)
{
return false;
}

for (int i = 0; i < left.Count; i++)
{
if (!equalityComparer.Equals(left[i], right[i]))
{
return false;
}
}

return true;
}
}
28 changes: 28 additions & 0 deletions test/Nerdbank.MessagePack.Tests/MessagePackSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ public partial class MessagePackSerializerTests(ITestOutputHelper logger)
[Fact]
public void ImmutableDictionary() => this.AssertRoundtrip(new ClassWithImmutableDictionary { StringInt = ImmutableDictionary<string, int>.Empty.Add("a", 1) });

[Fact]
public void Array() => this.AssertRoundtrip(new ClassWithArray { IntArray = [1, 2, 3] });

[Fact]
public void Array_Null() => this.AssertRoundtrip(new ClassWithArray { IntArray = null });

[Fact]
public void Enumerable() => this.AssertRoundtrip(new ClassWithEnumerable { IntEnum = [1, 2, 3] });

[Fact]
public void Enumerable_Null() => this.AssertRoundtrip(new ClassWithEnumerable { IntEnum = null });

protected void AssertRoundtrip<T>(T? value)
where T : IShapeable<T>
{
Expand Down Expand Up @@ -137,4 +149,20 @@ public partial class ClassWithImmutableDictionary : IEquatable<ClassWithImmutabl

public bool Equals(ClassWithImmutableDictionary? other) => other is not null && ByValueEquality.Equal(this.StringInt, other.StringInt);
}

[GenerateShape]
public partial class ClassWithArray : IEquatable<ClassWithArray>
{
public int[]? IntArray { get; set; }

public bool Equals(ClassWithArray? other) => other is not null && ByValueEquality.Equal(this.IntArray, other.IntArray);
}

[GenerateShape]
public partial class ClassWithEnumerable : IEquatable<ClassWithEnumerable>
{
public IEnumerable<int>? IntEnum { get; set; }

public bool Equals(ClassWithEnumerable? other) => other is not null && ByValueEquality.Equal(this.IntEnum, other.IntEnum);
}
}

0 comments on commit 12ffadd

Please sign in to comment.