Skip to content
This repository has been archived by the owner on Jan 16, 2022. It is now read-only.

Commit

Permalink
#133: Handle type nullability in custom conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
cezarypiatek committed Aug 31, 2020
1 parent 76445d1 commit 92c142c
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
{
throw new NotImplementedException();
}

private static CategoryDTO? MapCategory(CategoryEntity category)
{
throw new NotImplementedException();
}
}

public class UserDTO
Expand Down Expand Up @@ -41,6 +46,7 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public int Total { get; set; }
public int UnitId { get; set; }
public UserSourceDTO Source { get; set; }
public CategoryDTO Category{ get; set; }
}

public class AccountDTO
Expand Down Expand Up @@ -76,6 +82,10 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
}
}

public class CategoryDTO
{
}

//---- Entities

public class UserEntity
Expand All @@ -99,6 +109,7 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public int? GetTotal() => throw new Exception();
public UnitEntity? Unit { get; set; }
public UserSourceEntity Source { get; set; }
public CategoryEntity? Category{ get; set; }
}

public class AccountEntity
Expand Down Expand Up @@ -128,4 +139,8 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public string? ProviderAddress { get; set; }
public bool? IsActive {get; set;}
}

public class CategoryEntity
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,15 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
Address6 = entity.Address6?.OfType<AddressEntity>().Select(entityAddress6 => new AddressDTO(entityAddress6)).ToList() ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Address6' should not be null"),
Total = entity.GetTotal() ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.GetTotal()' should not be null"),
UnitId = entity.Unit?.Id ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Unit?.Id' should not be null"),
Source = new UserSourceDTO(providerName: entity.Source.ProviderName, providerAddress: entity.Source.ProviderAddress ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Source.ProviderAddress' should not be null"), isActive: entity.Source.IsActive ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Source.IsActive' should not be null"))
Source = new UserSourceDTO(providerName: entity.Source.ProviderName, providerAddress: entity.Source.ProviderAddress ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Source.ProviderAddress' should not be null"), isActive: entity.Source.IsActive ?? throw new ArgumentNullException(nameof(entity), "The value of 'entity.Source.IsActive' should not be null")),
Category = entity.Category != null ? MapCategory(entity.Category) ?? throw new NullReferenceException("The value of 'MapCategory(entity.Category)' should not be null") : throw new ArgumentNullException(nameof(entity), "The value of 'entity.Category' should not be null")
} : throw new ArgumentNullException(nameof(entity), "The value of 'entity' should not be null");
}

private static CategoryDTO? MapCategory(CategoryEntity category)
{
throw new NotImplementedException();
}
}

public class UserDTO
Expand Down Expand Up @@ -82,6 +88,7 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public int Total { get; set; }
public int UnitId { get; set; }
public UserSourceDTO Source { get; set; }
public CategoryDTO Category{ get; set; }
}

public class AccountDTO
Expand Down Expand Up @@ -117,6 +124,10 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
}
}

public class CategoryDTO
{
}

//---- Entities

public class UserEntity
Expand All @@ -140,6 +151,7 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public int? GetTotal() => throw new Exception();
public UnitEntity? Unit { get; set; }
public UserSourceEntity Source { get; set; }
public CategoryEntity? Category{ get; set; }
}

public class AccountEntity
Expand Down Expand Up @@ -169,4 +181,8 @@ namespace MappingGenerator.Test.MappingGenerator.TestCaseData
public string? ProviderAddress { get; set; }
public bool? IsActive {get; set;}
}

public class CategoryEntity
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ private async Task<Document> GenerateMappingMethodBody(Document document, BaseMe
continue;
}

mappingContext.CustomConversions[(userDefinedConversion.Parameters.First().Type, userDefinedConversion.ReturnType)] = (ExpressionSyntax) generator.IdentifierName(userDefinedConversion.Name);
mappingContext.CustomConversions.Add(new CustomConversion()
{
FromType = new AnnotatedType(userDefinedConversion.Parameters.First().Type),
ToType = new AnnotatedType(userDefinedConversion.ReturnType),
Conversion = SyntaxFactory.IdentifierName(userDefinedConversion.Name)
});
}
}
var blockSyntax = MappingImplementorEngine.GenerateMappingBlock(methodSymbol, generator, semanticModel, mappingContext);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using MappingGenerator.Mappings.SourceFinders;
using Microsoft.CodeAnalysis;

namespace MappingGenerator.Mappings
{
public class AnnotatedType
{
public ITypeSymbol Type { get; }
public bool CanBeNull { get; }

public AnnotatedType(ITypeSymbol type)
{
Type = type;
CanBeNull = type.CanBeNull();
}

public AnnotatedType(ITypeSymbol type, bool canBeNull)
{
Type = type;
CanBeNull = canBeNull;
}

public AnnotatedType AsNotNull()
{
return new AnnotatedType(Type, false);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http.Headers;
using MappingGenerator.RoslynHelpers;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -35,16 +37,44 @@ public void AddMissingConversion(ITypeSymbol fromType, ITypeSymbol toType) => Mi

public bool WrapInCustomConversion { get; set; }

public Dictionary<(ITypeSymbol fromType, ITypeSymbol toType), ExpressionSyntax> CustomConversions { get; } = new Dictionary<(ITypeSymbol fromType, ITypeSymbol toType), ExpressionSyntax>();
public List<CustomConversion> CustomConversions { get; } = new List<CustomConversion>();

public ExpressionSyntax? FindConversion(ITypeSymbol fromType, ITypeSymbol toType)
public CustomConversion? FindConversion(AnnotatedType fromType, AnnotatedType toType)
{
if (CustomConversions.TryGetValue((fromType, toType), out var conversion))
if (CustomConversions.Count == 0)
{
return conversion;
return null;
}

return null;
var candidates = CustomConversions.Where(x => x.FromType.Type.Equals(fromType.Type) && x.ToType.Type.Equals(toType.Type)).ToList();
if (candidates.Count == 0)
{
return null;
}
if (candidates.Count == 1)
{
return candidates[0];
}
if (candidates.Count > 1)
{
var exactlyConversion = candidates.FirstOrDefault(x => x.FromType.CanBeNull == fromType.CanBeNull && x.ToType.CanBeNull == toType.CanBeNull);
if (exactlyConversion != null)
{
return exactlyConversion;
}

return candidates.FirstOrDefault(x => x.FromType.CanBeNull == fromType.CanBeNull || x.ToType.CanBeNull == toType.CanBeNull);
}

return candidates.FirstOrDefault();
}
}

public class CustomConversion
{
public AnnotatedType FromType { get; set; }
public AnnotatedType ToType { get; set; }

public ExpressionSyntax Conversion { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,9 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Simplification;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace MappingGenerator.Mappings
{

public class AnnotatedType
{
public ITypeSymbol Type { get; }
public bool CanBeNull { get; }

public AnnotatedType(ITypeSymbol type)
{
Type = type;
CanBeNull = type.CanBeNull();
}

public AnnotatedType(ITypeSymbol type, bool canBeNull)
{
Type = type;
CanBeNull = canBeNull;
}

public AnnotatedType AsNotNull()
{
return new AnnotatedType(Type, false);
}
}

public class MappingEngine
{
protected readonly SemanticModel semanticModel;
Expand Down Expand Up @@ -94,15 +69,16 @@ public MappingElement MapExpression(MappingElement source, AnnotatedType targetT
};
}

if (mappingContext.FindConversion(sourceType.Type, targetType.Type) is {} userDefinedConversion)
if (mappingContext.FindConversion(sourceType, targetType) is {} userDefinedConversion)
{
//TODO: Check if conversion accept nullable type
//TODO: Check if the result of the conversion is nullable-compatible
var invocationExpression = (ExpressionSyntax)syntaxGenerator.InvocationExpression(userDefinedConversion, source.Expression);
return new MappingElement()
var invocationExpression = (ExpressionSyntax)syntaxGenerator.InvocationExpression(userDefinedConversion.Conversion, source.Expression);
var protectResultFromNull = targetType.CanBeNull == false && userDefinedConversion.ToType.CanBeNull;
var conversionExpression = protectResultFromNull ? OrFailWhenExpressionNull(invocationExpression) : invocationExpression;
return new MappingElement
{
ExpressionType = targetType,
Expression = HandleSafeNull(source, targetType, invocationExpression),
Expression = HandleSafeNull(source, userDefinedConversion.FromType, conversionExpression)
};
}

Expand All @@ -111,7 +87,7 @@ public MappingElement MapExpression(MappingElement source, AnnotatedType targetT
{
return new MappingElement
{
Expression = OrFailWhenNull(source.Expression),
Expression = OrFailWhenArgumentNull(source.Expression),
ExpressionType = new AnnotatedType(underlyingType, false)
};
}
Expand All @@ -124,7 +100,7 @@ public MappingElement MapExpression(MappingElement source, AnnotatedType targetT
return new MappingElement
{
ExpressionType = conversion.ExpressionType.AsNotNull(),
Expression = OrFailWhenNull(conversion.Expression)
Expression = OrFailWhenArgumentNull(conversion.Expression)
};
}
return conversion;
Expand All @@ -140,17 +116,21 @@ public MappingElement MapExpression(MappingElement source, AnnotatedType targetT
{
return new MappingElement
{
Expression = OrFailWhenNull(source.Expression),
Expression = OrFailWhenArgumentNull(source.Expression),
ExpressionType = source.ExpressionType.AsNotNull()
};
}
return source;
}

private static BinaryExpressionSyntax OrFailWhenNull(ExpressionSyntax expression, string messageExpression = null)
private static BinaryExpressionSyntax OrFailWhenArgumentNull(ExpressionSyntax expression, string messageExpression = null)
{
return BinaryExpression(SyntaxKind.CoalesceExpression, expression, ThrowNullArgumentException(messageExpression ?? expression.ToFullString()));
}
private static BinaryExpressionSyntax OrFailWhenExpressionNull(ExpressionSyntax expression)
{
return BinaryExpression(SyntaxKind.CoalesceExpression, expression, ThrowNullReferenceException(expression.ToFullString()));
}

protected virtual bool ShouldCreateConversionBetweenTypes(ITypeSymbol targetType, ITypeSymbol sourceType)
{
Expand Down Expand Up @@ -184,7 +164,7 @@ protected virtual MappingElement TryToCreateMappingExpression(MappingElement sou
return new MappingElement
{
ExpressionType = targetType,
Expression = shouldProtectAgainstNull ? OrFailWhenNull(collectionMapping, source.Expression.ToFullString()) : collectionMapping,
Expression = shouldProtectAgainstNull ? OrFailWhenArgumentNull(collectionMapping, source.Expression.ToFullString()) : collectionMapping,
};
}

Expand Down Expand Up @@ -245,6 +225,15 @@ private static ThrowExpressionSyntax ThrowNullArgumentException(string expressio
var throwExpressionSyntax = SyntaxFactory.ThrowExpression(CreateObject(exceptionTypeName, exceptionParameters));
return throwExpressionSyntax;
}

private static ThrowExpressionSyntax ThrowNullReferenceException(string expressionText)
{
var errorMessageExpression = LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal($"The value of '{expressionText}' should not be null"));
var exceptionTypeName = SyntaxFactory.IdentifierName("NullReferenceException");
var exceptionParameters = ArgumentList(new SeparatedSyntaxList<ArgumentSyntax>().AddRange(new []{ Argument(errorMessageExpression)}));
var throwExpressionSyntax = SyntaxFactory.ThrowExpression(CreateObject(exceptionTypeName, exceptionParameters));
return throwExpressionSyntax;
}

private ObjectCreationExpressionSyntax CreateObject(ITypeSymbol type, ArgumentListSyntax argumentList = null)
{
Expand Down
7 changes: 6 additions & 1 deletion MappingGenerator/OnBuildGenerator/OnBuildMappingGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ public Task<GenerationResult> GenerateAsync(CSharpSyntaxNode processedNode, Attr
{
if (x.IsStatic && accessibilityHelper.IsSymbolAccessible(x, interfaceSymbol))
{
mappingContext.CustomConversions[(x.Parameters[0].Type, x.ReturnType)] = (ExpressionSyntax)syntaxGenerator.MemberAccessExpression((ExpressionSyntax)syntaxGenerator.IdentifierName(x.ContainingType.ToDisplayString()), x.Name);
mappingContext.CustomConversions.Add(new CustomConversion()
{
FromType = new AnnotatedType(x.Parameters[0].Type),
ToType = new AnnotatedType(x.ReturnType),
Conversion = (ExpressionSyntax)syntaxGenerator.MemberAccessExpression((ExpressionSyntax)syntaxGenerator.IdentifierName(x.ContainingType.ToDisplayString()), x.Name)
});
}
}

Expand Down

0 comments on commit 92c142c

Please sign in to comment.