-
Notifications
You must be signed in to change notification settings - Fork 2
/
tester.go
278 lines (245 loc) · 6.76 KB
/
tester.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
package sol
import (
"fmt"
"reflect"
"runtime"
"strings"
"testing"
"time"
"github.com/aodin/sol/dialect"
"github.com/aodin/sol/types"
)
// callerInfo reports "file:line" for the nearest stack frame that does not
// belong to tester.go, so tester failures point at the test code that made
// the failing assertion. Only the base name of the file is reported. A frame
// whose file is "<autogenerated>" ends the search and is reported as-is. If
// the stack runs out first, the empty string is returned.
//
// Adapted from:
// https://github.com/stretchr/testify/blob/master/assert/assertions.go
// Copyright (c) 2012 - 2013 Mat Ryer and Tyler Bunnell
func callerInfo() string {
	for depth := 0; ; depth++ {
		_, path, lineNo, found := runtime.Caller(depth)
		if !found {
			return ""
		}
		if path != "<autogenerated>" {
			// Keep only the final slash-separated element of the path
			path = path[strings.LastIndex(path, "/")+1:]
			if path == "tester.go" {
				// Still inside the tester helpers; keep unwinding
				continue
			}
		}
		return fmt.Sprintf("%s:%d", path, lineNo)
	}
}
// tester compiles statements with one fixed dialect and reports unexpected
// SQL, parameters, or errors through the wrapped *testing.T.
type tester struct {
	t       *testing.T      // receives failures reported by Error and SQL
	dialect dialect.Dialect // dialect handed to every stmt.Compile call
}
// Error asserts that compiling stmt with the tester's dialect fails. Any
// non-nil error satisfies the assertion; a successful compile is reported
// as a test failure annotated with the caller's file and line.
func (t *tester) Error(stmt Compiles) {
	// TODO Allow a specific error
	_, err := stmt.Compile(t.dialect, Params())
	if err != nil {
		return
	}
	t.t.Errorf("%s: expected error, received nil", callerInfo())
}
// SQL asserts that stmt compiles, under the tester's dialect, to exactly
// expect. It also asserts that the parameters collected during compilation
// equal ps, compared in order with reflect.DeepEqual.
//
// A compile error is reported and ends the check. An SQL mismatch is
// reported, but the parameters are still examined. A parameter-count
// mismatch is reported and skips the per-element comparison. In every
// message "have" is what Compile produced and "want" is what the caller
// expected.
func (t *tester) SQL(stmt Compiles, expect string, ps ...interface{}) {
	// Capture the caller up front so every failure points at the test line
	caller := callerInfo()

	// Start a new parameters instance and compile with the tester's dialect
	params := Params()
	actual, err := stmt.Compile(t.dialect, params)
	if err != nil {
		t.t.Errorf("%s:\n\rUnexpected error from Compile: %s", caller, err)
		return
	}
	if expect != actual {
		t.t.Errorf(
			"%s:\n\rUnexpected SQL from Compile: \n\r - have: %s\n\r - want: %s",
			caller, actual, expect,
		)
	}

	// Test that the parameters are equal. The compiled count is "have" and
	// the expected count is "want" (these were previously swapped).
	if len(*params) != len(ps) {
		t.t.Errorf(
			"%s:\n\rUnexpected parameter length for %T: \n\r - have %d, want %d",
			caller, stmt, len(*params), len(ps),
		)
		return
	}

	// Examine individual parameters for equality; the lengths match here,
	// so ps[i] is always in range
	for i, param := range *params {
		if !reflect.DeepEqual(ps[i], param) {
			t.t.Errorf(
				"%s:\n\rUnequal parameters at index %d: \n\r - have %#v (%T), want %#v (%T)",
				caller, i, param, param, ps[i], ps[i],
			)
		}
	}
}
// NewTester returns an SQL/Error tester that compiles statements with
// dialect d and reports any failures through t.
func NewTester(t *testing.T, d dialect.Dialect) *tester {
	tt := new(tester)
	tt.t, tt.dialect = t, d
	return tt
}
// IntegrationTest runs a large, dialect-neutral round trip against conn. It
// exercises CREATE TABLE, INSERT (by struct and by Values), SELECT (single
// row, ORDER BY, COUNT), UPDATE, DELETE and DROP TABLE. It then checks that
// conn.Must() panics when a query fails.
//
// The statements run inside one transaction, and a deferred Rollback
// discards it. When ddlCommit is true, CREATE TABLE is instead issued
// directly on conn. This is for databases such as MySQL, where any DDL
// change is an implicit commit.
// Other databases: http://stackoverflow.com/a/4736346
func IntegrationTest(t *testing.T, conn *DB, ddlCommit bool) {
	// TODO What features should be tested outside of a transaction?
	testusers := Table("testusers",
		Column("id", types.Integer()),
		Column("email", types.Varchar().Limit(255).NotNull()),
		Column("is_admin", types.Boolean().NotNull()),
		Column("created_at", types.Timestamp()),
		PrimaryKey("id"),
		Unique("email"),
	)

	type testuser struct {
		ID        int64
		Email     string
		IsAdmin   bool
		CreatedAt time.Time
	}

	// sameUser compares two rows field by field. A time.Time must be
	// compared with Equal rather than ==, because == also compares the
	// Location pointer and monotonic clock reading, which can differ
	// between scans of the same instant.
	sameUser := func(a, b testuser) bool {
		return a.ID == b.ID &&
			a.Email == b.Email &&
			a.IsAdmin == b.IsAdmin &&
			a.CreatedAt.Equal(b.CreatedAt)
	}

	// The emails must differ. The table has a UNIQUE constraint on email,
	// and the ORDER BY check below identifies the client row by its email.
	const (
		adminEmail  = "admin@example.com"
		clientEmail = "client@example.com"
	)

	// Perform all tests in a transaction
	tx, err := conn.Begin()
	if err != nil {
		t.Fatalf("Creating a new transaction should not error: %s", err)
	}
	// The Rollback error is deliberately ignored. After a DDL statement
	// implicitly commits (e.g. MySQL), there may be nothing left to roll back.
	defer tx.Rollback()

	// CREATE TABLE runs outside the transaction when DDL would implicitly
	// commit it
	create := testusers.Create().IfNotExists()
	if ddlCommit {
		err = conn.Query(create)
	} else {
		err = tx.Query(create)
	}
	if err != nil {
		t.Fatalf("CREATE TABLE should not error: %s", err)
	}

	// INSERT by struct
	// Truncate the time.Time field to avoid significant digit errors
	admin := testuser{
		ID:        1,
		Email:     adminEmail,
		IsAdmin:   true,
		CreatedAt: time.Now().UTC().Truncate(time.Second),
	}
	if err = tx.Query(testusers.Insert().Values(admin)); err != nil {
		t.Fatalf("INSERT by struct should not fail %s", err)
	}

	// SELECT
	var selected testuser
	if err = tx.Query(
		testusers.Select().Where(testusers.C("id").Equals(admin.ID)),
		&selected,
	); err != nil {
		t.Fatalf("SELECT should not fail: %s", err)
	}

	// TODO test with direct comparison: selected == admin
	// For now, test each field since DATETIME handling is terribly
	// inconsistent across databases
	if selected.ID != admin.ID {
		t.Errorf(
			"Unequal testusers id: have %d, want %d",
			selected.ID, admin.ID,
		)
	}
	if selected.Email != admin.Email {
		t.Errorf(
			"Unequal testusers email: have %s, want %s",
			selected.Email, admin.Email,
		)
	}
	if selected.IsAdmin != admin.IsAdmin {
		t.Errorf(
			"Unequal testusers is_admin: have %t, want %t",
			selected.IsAdmin, admin.IsAdmin,
		)
	}
	if !selected.CreatedAt.Equal(admin.CreatedAt) {
		t.Errorf(
			"Unequal testusers created_at: have %v, want %v",
			selected.CreatedAt, admin.CreatedAt,
		)
	}

	// UPDATE
	if err = tx.Query(
		testusers.Update().Values(
			Values{"is_admin": false},
		).Where(testusers.C("id").Equals(admin.ID)),
	); err != nil {
		t.Fatalf("UPDATE should not fail: %s", err)
	}

	// Only the admin row exists at this point, so LIMIT 1 selects it
	var updated testuser
	if err = tx.Query(testusers.Select().Limit(1), &updated); err != nil {
		t.Fatalf("SELECT should not fail: %s", err)
	}
	selected.IsAdmin = false
	if !sameUser(updated, selected) {
		t.Errorf(
			"Unequal testusers: have %+v, want %+v",
			updated, selected,
		)
	}

	// INSERT by values
	client := Values{
		"id":         2,
		"email":      clientEmail,
		"is_admin":   false,
		"created_at": time.Now().UTC().Truncate(time.Second),
	}
	if err = tx.Query(testusers.Insert().Values(client)); err != nil {
		t.Fatalf("INSERT by values should not fail %s", err)
	}

	var list []testuser
	if err = tx.Query(
		testusers.Select().OrderBy(testusers.C("id").Desc()),
		&list,
	); err != nil {
		t.Fatalf("SELECT with ORDER BY should not fail: %s", err)
	}
	if len(list) != 2 {
		t.Fatalf("Unexpected length of list: want 2, have %d", len(list))
	}

	// The client (id 2) should be first under ORDER BY id DESC
	if list[0].Email != clientEmail {
		t.Errorf(
			"Unexpected email: want %s, have %s",
			clientEmail, list[0].Email,
		)
	}

	var count int64
	if err = tx.Query(Select(Count(testusers.C("id"))), &count); err != nil {
		t.Fatalf("SELECT with COUNT should not fail: %s", err)
	}
	if count != 2 {
		t.Errorf("Unexpected COUNT: want 2, have %d", count)
	}

	// DELETE
	if err = tx.Query(
		testusers.Delete().Where(testusers.C("email").Equals(admin.Email)),
	); err != nil {
		t.Fatalf("DELETE should not fail: %s", err)
	}

	// DROP TABLE
	// TODO Since this is a DDL, this will likely commit in MySQL
	if err = tx.Query(testusers.Drop()); err != nil {
		t.Fatalf("DROP TABLE should not fail %s", err)
	}

	// Test a recover: a failing query through conn.Must() should panic.
	// NOTE(review): list is passed by value rather than &list, so the panic
	// may come from the destination argument rather than the dropped table.
	// Confirm which failure this is meant to exercise.
	func() {
		defer func() {
			if panicked := recover(); panicked == nil {
				t.Errorf("Connection failed to panic on error")
			}
		}()
		conn.Must().Query(testusers.Select(), list)
	}()
}