-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
join.go
88 lines (77 loc) · 1.54 KB
/
join.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package linq
// joinEnumerator is the Enumerator that Join hands out. It walks the outer
// sequence one element at a time and pairs each outer element with every
// inner element whose key is equal to the outer element's key.
type joinEnumerator[S1, S2, T any, K comparable] struct {
	eOut  Enumerator[S1]          // outer sequence, read one element at a time
	eIn   Enumerator[S2]          // inner sequence, drained into ms2 by Next
	ksOut func(S1) (K, error)     // key selector for outer elements
	ksIn  func(S2) (K, error)     // key selector for inner elements
	rSel  func(S1, S2) (T, error) // builds a result from a matching (outer, inner) pair
	s1    *S1                     // current outer element; nil when the next one must be fetched
	ks1   K                       // key of *s1
	ms2   map[K][]S2              // inner elements grouped by key; nil until built
	i     int                     // index into ms2[ks1] of the next match to yield
}
// Join correlates the elements of two sequences based on matching keys.
//
// The returned Enumerable is lazy: outer and inner are not invoked until the
// Enumerable itself is called. Every call opens both sequences afresh (outer
// first, then inner) and returns a new enumerator. That enumerator pairs each
// outer element with every inner element whose key is equal, and maps each
// pair through resultSelector.
func Join[S1, S2, T any, K comparable, E1 IEnumerable[S1], E2 IEnumerable[S2]](
	outer E1,
	inner E2,
	outerKeySelector func(S1) (K, error),
	innerKeySelector func(S2) (K, error),
	resultSelector func(S1, S2) (T, error),
) Enumerable[T] {
	return func() Enumerator[T] {
		// Open the sources in a fixed order: outer, then inner.
		outerEnum := outer()
		innerEnum := inner()

		je := new(joinEnumerator[S1, S2, T, K])
		je.eOut, je.eIn = outerEnum, innerEnum
		je.ksOut, je.ksIn = outerKeySelector, innerKeySelector
		je.rSel = resultSelector
		return je
	}
}
// Next returns the next element of the join.
//
// For the current outer element, it yields e.rSel(outer, inner) for each inner
// element whose key equals the outer element's key. Inner elements come in the
// order they were read from the inner sequence. Next then moves on to the
// following outer element. Outer elements with no matching inner element
// produce nothing.
//
// The inner sequence is drained into e.ms2 only after an outer element has
// been obtained, so an empty outer sequence never reads the inner one. Errors
// from the outer enumerator (including its end-of-collection error), the key
// selectors, innerMap, or rSel are returned unchanged.
//
// NOTE(review): if innerMap fails, e.ms2 stays nil and a later call rebuilds
// it from the partially consumed inner enumerator. Callers presumably stop at
// the first error; confirm.
func (e *joinEnumerator[S1, S2, T, K]) Next() (def T, _ error) {
	// Loop instead of recursing: a long run of outer elements without matches
	// must not grow the call stack (Go stacks have a hard size limit).
	for {
		if e.s1 == nil {
			s1, err := e.eOut.Next()
			if err != nil {
				return def, err
			}
			ks1, err := e.ksOut(s1)
			if err != nil {
				return def, err
			}
			e.s1 = &s1
			e.ks1 = ks1
		}
		if e.ms2 == nil {
			m, err := innerMap(e.eIn, e.ksIn)
			if err != nil {
				return def, err
			}
			e.ms2 = m
		}
		matches := e.ms2[e.ks1]
		if e.i < len(matches) {
			i := e.i
			e.i++
			return e.rSel(*e.s1, matches[i])
		}
		// Matches for the current outer element are exhausted; fetch the next one.
		e.i = 0
		e.s1 = nil
	}
}
// innerMap drains e and groups every element it yields by the key ks computes
// for it. Within each group, elements keep the order in which e produced them.
//
// The grouping is returned (never nil) once e reports end-of-collection, as
// recognized by isEOC. Any other error from e, or any error from ks, is
// returned together with a nil map.
func innerMap[S2 any, K comparable](e Enumerator[S2], ks func(S2) (K, error)) (map[K][]S2, error) {
	lookup := map[K][]S2{}
	for item, err := e.Next(); ; item, err = e.Next() {
		if err != nil {
			if !isEOC(err) {
				return nil, err
			}
			return lookup, nil
		}
		key, keyErr := ks(item)
		if keyErr != nil {
			return nil, keyErr
		}
		lookup[key] = append(lookup[key], item)
	}
}