diff --git a/src/PSLambda/CompileVisitor.cs b/src/PSLambda/CompileVisitor.cs index 8c422c8..bff0b27 100644 --- a/src/PSLambda/CompileVisitor.cs +++ b/src/PSLambda/CompileVisitor.cs @@ -431,20 +431,36 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst) { using (_loops.NewScope()) { - var enumerator = Call( - ReflectionCache.LanguagePrimitives_GetEnumerator, - forEachStatementAst.Condition.Compile(this)); + var condition = forEachStatementAst.Condition.Compile(this); + var canEnumerate = TryGetEnumeratorMethod( + condition.Type, + out MethodInfo getEnumeratorMethod, + out MethodInfo getCurrentMethod); + + if (!canEnumerate) + { + Errors.ReportParseError( + forEachStatementAst.Condition.Extent, + nameof(ErrorStrings.ForEachInvalidEnumerable), + string.Format( + CultureInfo.CurrentCulture, + ErrorStrings.ForEachInvalidEnumerable, + condition.Type)); + + Errors.ThrowIfAnyErrors(); + return Empty(); + } using (_scopeStack.NewScope()) { var enumeratorRef = _scopeStack.GetVariable( Strings.ForEachVariableName, - typeof(IEnumerator)); + getEnumeratorMethod.ReturnType); try { return Block( _scopeStack.GetVariables(), - Assign(enumeratorRef, enumerator), + Assign(enumeratorRef, Call(condition, getEnumeratorMethod)), Loop( IfThenElse( test: Call(enumeratorRef, ReflectionCache.IEnumerator_MoveNext), @@ -454,8 +470,8 @@ public object VisitForEachStatement(ForEachStatementAst forEachStatementAst) Assign( _scopeStack.GetVariable( forEachStatementAst.Variable.VariablePath.UserPath, - typeof(object)), - Property(enumeratorRef, ReflectionCache.IEnumerator_Current)), + getCurrentMethod.ReturnType), + Call(enumeratorRef, getCurrentMethod)), forEachStatementAst.Body.Compile(this) }), ifFalse: Break(_loops.Break)), @@ -1906,5 +1922,61 @@ private MemberBinder GetBinder(Ast ast) _binder = new MemberBinder(BindingFlags.Public, namespaces.ToArray()); return _binder; } + + private bool TryGetEnumeratorMethod( + Type type, + out MethodInfo getEnumeratorMethod, + out MethodInfo getCurrentMethod) + { + var canFallbackToEnumerable = false; + var canFallbackToIDictionary = false; + var interfaces = type.GetInterfaces(); + for (var i = 0; i < interfaces.Length; i++) + { + if (interfaces[i] == typeof(IEnumerable)) + { + canFallbackToEnumerable = true; + continue; + } + + if (interfaces[i] == typeof(IDictionary)) + { + canFallbackToIDictionary = true; + continue; + } + + if (interfaces[i].IsConstructedGenericType && + interfaces[i].GetGenericTypeDefinition() == typeof(IEnumerable<>)) + { + getEnumeratorMethod = interfaces[i].GetMethod( + Strings.GetEnumeratorMethodName, + Type.EmptyTypes); + + getCurrentMethod = + getEnumeratorMethod.ReturnType.GetMethod( + Strings.EnumeratorGetCurrentMethodName, + Type.EmptyTypes); + return true; + } + } + + if (canFallbackToIDictionary) + { + getEnumeratorMethod = ReflectionCache.IDictionary_GetEnumerator; + getCurrentMethod = ReflectionCache.IDictionaryEnumerator_get_Entry; + return true; + } + + if (canFallbackToEnumerable) + { + getEnumeratorMethod = ReflectionCache.IEnumerable_GetEnumerator; + getCurrentMethod = ReflectionCache.IEnumerator_get_Current; + return true; + } + + getEnumeratorMethod = null; + getCurrentMethod = null; + return false; + } } } diff --git a/src/PSLambda/ExpressionUtils.cs b/src/PSLambda/ExpressionUtils.cs index 8142793..f47c9ce 100644 --- a/src/PSLambda/ExpressionUtils.cs +++ b/src/PSLambda/ExpressionUtils.cs @@ -185,7 +185,7 @@ public static Expression PSConvertAllTo(Expression source) collectionVar, Strings.AddMethodName, Type.EmptyTypes, - PSConvertTo(Property(enumeratorVar, ReflectionCache.IEnumerator_Current))), + PSConvertTo(Call(enumeratorVar, ReflectionCache.IEnumerator_get_Current))), Break(breakLabel)), breakLabel), Call(collectionVar, Strings.ToArrayMethodName, Type.EmptyTypes)); @@ -305,7 +305,7 @@ public static Expression PSIsIn(Expression item, Expression items, bool isCaseSe IfThen( PSEquals( item, - Property(enumeratorVar, ReflectionCache.IEnumerator_Current), + Call(enumeratorVar, ReflectionCache.IEnumerator_get_Current), isCaseSensitive), Return(returnLabel, SpecialVariables.Constants[Strings.TrueVariableName])), Return(returnLabel, SpecialVariables.Constants[Strings.FalseVariableName]))), diff --git a/src/PSLambda/ReflectionCache.cs b/src/PSLambda/ReflectionCache.cs index 65bea1f..3bbb9ef 100644 --- a/src/PSLambda/ReflectionCache.cs +++ b/src/PSLambda/ReflectionCache.cs @@ -1,5 +1,6 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Linq; using System.Management.Automation; using System.Reflection; @@ -189,10 +190,16 @@ internal static class ReflectionCache typeof(Hashtable).GetConstructor(new[] { typeof(int), typeof(IEqualityComparer) }); /// - /// Resolves to + /// Resolves to /// - public static readonly PropertyInfo IEnumerator_Current = - typeof(IEnumerator).GetProperty("Current"); + public static readonly MethodInfo IEnumerator_get_Current = + typeof(IEnumerator).GetMethod(Strings.EnumeratorGetCurrentMethodName, Type.EmptyTypes); + + /// + /// Resolves to + /// + public static readonly MethodInfo IEnumerable_GetEnumerator = + typeof(IEnumerable).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes); /// /// Resolves to . @@ -225,5 +232,29 @@ internal static class ReflectionCache /// public static readonly MethodInfo Monitor_Exit = typeof(System.Threading.Monitor).GetMethod("Exit", new[] { typeof(object) }); + + /// + /// Resolves to . + /// + public static readonly MethodInfo IEnumerable_T_GetEnumerator = + typeof(IEnumerable<>).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes); + + /// + /// Resolves to . + /// + public static readonly MethodInfo IEnumerator_T_get_Current = + typeof(IEnumerator<>).GetMethod(Strings.EnumeratorGetCurrentMethodName, Type.EmptyTypes); + + /// + /// Resolves to . + /// + public static readonly MethodInfo IDictionary_GetEnumerator = + typeof(IDictionary).GetMethod(Strings.GetEnumeratorMethodName, Type.EmptyTypes); + + /// + /// Resolves to . + /// + public static readonly MethodInfo IDictionaryEnumerator_get_Entry = + typeof(IDictionaryEnumerator).GetMethod("get_Entry", Type.EmptyTypes); } } diff --git a/src/PSLambda/Strings.cs b/src/PSLambda/Strings.cs index 3898c25..c74750a 100644 --- a/src/PSLambda/Strings.cs +++ b/src/PSLambda/Strings.cs @@ -5,6 +5,16 @@ namespace PSLambda /// internal class Strings { + /// + /// Constant containing a string similar to "GetEnumerator". + /// + public const string GetEnumeratorMethodName = "GetEnumerator"; + + /// + /// Constant containing a string similar to "get_Current". + /// + public const string EnumeratorGetCurrentMethodName = "get_Current"; + /// /// Constant containing a string similar to "psdelegate". /// diff --git a/src/PSLambda/resources/ErrorStrings.resx b/src/PSLambda/resources/ErrorStrings.resx index 3796f2a..d5210e3 100644 --- a/src/PSLambda/resources/ErrorStrings.resx +++ b/src/PSLambda/resources/ErrorStrings.resx @@ -168,4 +168,7 @@ '{0}' does not contain a definition for a method named '{1}' that takes the specified arguments. + + The foreach statement cannot operate on variables of type '{0}' because '{0}' does not contain a public definition for 'GetEnumerator' + diff --git a/test/Loops.Tests.ps1 b/test/Loops.Tests.ps1 index 461b4c7..e193f99 100644 --- a/test/Loops.Tests.ps1 +++ b/test/Loops.Tests.ps1 @@ -4,32 +4,93 @@ $manifestPath = "$PSScriptRoot\..\Release\$moduleName\*\$moduleName.psd1" Import-Module $manifestPath -Force Describe 'basic loop functionality' { - It 'for statement' { - $delegate = New-PSDelegate { - [int] $total = 0 - for ([int] $i = 0; $i -lt 10; $i++) { - $total = $i + $total + Context 'foreach statement tests' { + It 'can enumerate IEnumerable<>' { + $delegate = New-PSDelegate { + $total = 0 + foreach($item in 0..10) { + $total = $item + $total + } + + return $total } - return $total + $delegate.Invoke() | Should -Be 55 } - $delegate.Invoke() | Should -Be 45 + It 'can enumerate IDictionary' { + $delegate = New-PSDelegate { + $hashtable = @{ + one = 'two' + three = 'four' + } + + $sb = [System.Text.StringBuilder]::new() + foreach($item in $hashtable) { + $sb.Append($item.Value.ToString()) + } + + return $sb.ToString() + } + + $delegate.Invoke() | Should -Be twofour + } + + It 'prioritizes IEnumerable<> over IDictionary' { + $delegate = New-PSDelegate { + $map = [System.Collections.Generic.Dictionary[string, int]]::new() + $map.Add('test', 10) + $map.Add('test2', 30) + + $results = [System.Collections.Generic.List[int]]::new() + foreach ($item in $map) { + $results.Add($item.Value) + } + + return $results + } + + $delegate.Invoke() | Should -Be 10, 30 + } + + It 'can enumerable IEnumerable' { + $delegate = New-PSDelegate { + $list = [System.Collections.ArrayList]::new() + $list.Add([object]10) + $list.Add('test2') + + $results = [System.Collections.Generic.List[string]]::new() + foreach ($item in $list) { + $results.Add($item.ToString()) + } + + return $results + } + + $delegate.Invoke() | Should -Be 10, test2 + } + + It 'throws the correct message when target is not IEnumerable' { + $expectedMsg = + "The foreach statement cannot operate on variables of type " + + "'System.Int32' because 'System.Int32' does not contain a " + + "public definition for 'GetEnumerator'" + + { New-PSDelegate { foreach ($a in 10) {}}} | Should -Throw $expectedMsg + } } - It 'foreach statement' { + It 'for statement' { $delegate = New-PSDelegate { - [int[]] $numbers = 1, 2, 3, 4 [int] $total = 0 - - foreach($item in $numbers) { - $total = [int]$item + [int]$total + for ([int] $i = 0; $i -lt 10; $i++) { + $total = $i + $total } return $total } - $delegate.Invoke() | Should -Be 10 + $delegate.Invoke() | Should -Be 45 } It 'while statement' {