Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove 'RETURNING' functionality from MultiInserter #7740

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions db/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,9 @@ type Executor interface {
OneSelector
Inserter
SelectExecer
Queryer
Delete(context.Context, ...interface{}) (int64, error)
Get(context.Context, interface{}, ...interface{}) (interface{}, error)
Update(context.Context, ...interface{}) (int64, error)
}

// Queryer offers the QueryContext method. Note that this is not read-only (i.e. not
// Selector), since a QueryContext can be `INSERT`, `UPDATE`, etc. The difference
// between QueryContext and ExecContext is that QueryContext can return rows. So for instance it is
// suitable for inserting rows and getting back ids.
type Queryer interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
}

Expand Down
78 changes: 21 additions & 57 deletions db/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,24 @@ import (
)

// MultiInserter makes it easy to construct a
// `INSERT INTO table (...) VALUES ... RETURNING id;`
// `INSERT INTO table (...) VALUES ...;`
// query which inserts multiple rows into the same table. It can also execute
// the resulting query.
type MultiInserter struct {
// These are validated by the constructor as containing only characters
// that are allowed in an unquoted identifier.
// https://mariadb.com/kb/en/identifier-names/#unquoted
table string
fields []string
returningColumn string
table string
fields []string

values [][]interface{}
}

// NewMultiInserter creates a new MultiInserter, checking for reasonable table
// name and list of fields. returningColumn is the name of a column to be used
// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz`
// clause is used. If returningColumn is present, it must refer to a column
// that can be parsed into an int64.
// Safety: `table`, `fields`, and `returningColumn` must contain only strings
// that are known at compile time. They must not contain user-controlled
// strings.
func NewMultiInserter(table string, fields []string, returningColumn string) (*MultiInserter, error) {
// name and list of fields.
// Safety: `table` and `fields` must contain only strings that are known at
// compile time. They must not contain user-controlled strings.
func NewMultiInserter(table string, fields []string) (*MultiInserter, error) {
if len(table) == 0 || len(fields) == 0 {
return nil, fmt.Errorf("empty table name or fields list")
}
Expand All @@ -44,18 +39,11 @@ func NewMultiInserter(table string, fields []string, returningColumn string) (*M
return nil, err
}
}
if returningColumn != "" {
err := validMariaDBUnquotedIdentifier(returningColumn)
if err != nil {
return nil, err
}
}

return &MultiInserter{
table: table,
fields: fields,
returningColumn: returningColumn,
values: make([][]interface{}, 0),
table: table,
fields: fields,
values: make([][]interface{}, 0),
}, nil
}

Expand Down Expand Up @@ -84,56 +72,32 @@ func (mi *MultiInserter) query() (string, []interface{}) {

questions := strings.TrimRight(questionsBuf.String(), ",")

// Safety: we are interpolating `mi.returningColumn` into an SQL query. We
// know it is a valid unquoted identifier in MariaDB because we verified
// that in the constructor.
returning := ""
if mi.returningColumn != "" {
returning = fmt.Sprintf(" RETURNING %s", mi.returningColumn)
}
// Safety: we are interpolating `mi.table` and `mi.fields` into an SQL
// query. We know they contain, respectively, a valid unquoted identifier
// and a slice of valid unquoted identifiers because we verified that in
// the constructor. We know the query overall has valid syntax because we
// generate it entirely within this function.
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s%s", mi.table, strings.Join(mi.fields, ","), questions, returning)
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions)

return query, queryArgs
}

// Insert inserts all the collected rows into the database represented by
// `queryer`. If a non-empty returningColumn was provided, then it returns
// the list of values from that column returned by the query.
func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) {
// `queryer`.
func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error {
query, queryArgs := mi.query()
rows, err := queryer.QueryContext(ctx, query, queryArgs...)
res, err := db.ExecContext(ctx, query, queryArgs...)
if err != nil {
return nil, err
return err
}

ids := make([]int64, 0, len(mi.values))
if mi.returningColumn != "" {
for rows.Next() {
var id int64
err = rows.Scan(&id)
if err != nil {
rows.Close()
return nil, err
}
ids = append(ids, id)
}
affected, err := res.RowsAffected()
if err != nil {
return err
}

// Hack: sometimes in unittests we make a mock Queryer that returns a nil
// `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()`
// on it will panic— but here we choose to treat it like an empty list,
// and skip calling `Close()` to avoid the panic.
if rows != nil {
err = rows.Close()
if err != nil {
return nil, err
}
if affected != int64(len(mi.values)) {
return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values))
}

return ids, nil
return nil
}
32 changes: 8 additions & 24 deletions db/multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,29 @@ import (
)

func TestNewMulti(t *testing.T) {
_, err := NewMultiInserter("", []string{"colA"}, "")
_, err := NewMultiInserter("", []string{"colA"})
test.AssertError(t, err, "Empty table name should fail")

_, err = NewMultiInserter("myTable", nil, "")
_, err = NewMultiInserter("myTable", nil)
test.AssertError(t, err, "Empty fields list should fail")

mi, err := NewMultiInserter("myTable", []string{"colA"}, "")
mi, err := NewMultiInserter("myTable", []string{"colA"})
test.AssertNotError(t, err, "Single-column construction should not fail")
test.AssertEquals(t, len(mi.fields), 1)

mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"}, "")
mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"})
test.AssertNotError(t, err, "Multi-column construction should not fail")
test.AssertEquals(t, len(mi.fields), 3)

_, err = NewMultiInserter("", []string{"colA"}, "colB")
test.AssertError(t, err, "expected error for empty table name")
_, err = NewMultiInserter("foo\"bar", []string{"colA"}, "colB")
_, err = NewMultiInserter("foo\"bar", []string{"colA"})
test.AssertError(t, err, "expected error for invalid table name")

_, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"}, "colB")
_, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"})
test.AssertError(t, err, "expected error for invalid column name")

_, err = NewMultiInserter("myTable", []string{"colA"}, "foo\"bar")
test.AssertError(t, err, "expected error for invalid returning column name")
}

func TestMultiAdd(t *testing.T) {
mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "")
mi, err := NewMultiInserter("table", []string{"a", "b", "c"})
test.AssertNotError(t, err, "Failed to create test MultiInserter")

err = mi.Add([]interface{}{})
Expand All @@ -57,7 +52,7 @@ func TestMultiAdd(t *testing.T) {
}

func TestMultiQuery(t *testing.T) {
mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "")
mi, err := NewMultiInserter("table", []string{"a", "b", "c"})
test.AssertNotError(t, err, "Failed to create test MultiInserter")
err = mi.Add([]interface{}{"one", "two", "three"})
test.AssertNotError(t, err, "Failed to insert test row")
Expand All @@ -67,15 +62,4 @@ func TestMultiQuery(t *testing.T) {
query, queryArgs := mi.query()
test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?)")
test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"})

mi, err = NewMultiInserter("table", []string{"a", "b", "c"}, "id")
test.AssertNotError(t, err, "Failed to create test MultiInserter")
err = mi.Add([]interface{}{"one", "two", "three"})
test.AssertNotError(t, err, "Failed to insert test row")
err = mi.Add([]interface{}{"egy", "kettö", "három"})
test.AssertNotError(t, err, "Failed to insert test row")

query, queryArgs = mi.query()
test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?) RETURNING id")
test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"})
}
8 changes: 1 addition & 7 deletions features/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Config struct {
CertCheckerRequiresCorrespondence bool
ECDSAForAll bool
CheckRenewalExemptionAtWFE bool
InsertAuthzsIndividually bool

// ServeRenewalInfo exposes the renewalInfo endpoint in the directory and for
// GET requests. WARNING: This feature is a draft and highly unstable.
Expand Down Expand Up @@ -115,13 +116,6 @@ type Config struct {
//
// This flag should only be used in conjunction with UseKvLimitsForNewOrder.
DisableLegacyLimitWrites bool

// InsertAuthzsIndividually causes the SA's NewOrderAndAuthzs method to
// create each new authz one at a time, rather than using MultiInserter.
// Although this is expected to be a performance penalty, it is necessary to
// get the AUTO_INCREMENT ID of each new authz without relying on MariaDB's
// unique "INSERT ... RETURNING" functionality.
InsertAuthzsIndividually bool
}

var fMu = new(sync.RWMutex)
Expand Down
7 changes: 3 additions & 4 deletions sa/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1047,12 +1047,12 @@ func deleteOrderFQDNSet(
return nil
}

func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error {
func addIssuedNames(ctx context.Context, queryer db.Execer, cert *x509.Certificate, isRenewal bool) error {
if len(cert.DNSNames) == 0 {
return berrors.InternalServerError("certificate has no DNSNames")
}

multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"}, "")
multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"})
if err != nil {
return err
}
Expand All @@ -1067,8 +1067,7 @@ func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certific
return err
}
}
_, err = multiInserter.Insert(ctx, queryer)
return err
return multiInserter.Insert(ctx, queryer)
}

func addKeyHash(ctx context.Context, db db.Inserter, cert *x509.Certificate) error {
Expand Down
59 changes: 11 additions & 48 deletions sa/sa.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"

"github.com/jmhodges/clock"
Expand Down Expand Up @@ -473,53 +472,17 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb

output, err := db.WithTransaction(ctx, ssa.dbMap, func(tx db.Executor) (interface{}, error) {
// First, insert all of the new authorizations and record their IDs.
newAuthzIDs := make([]int64, 0)
if features.Get().InsertAuthzsIndividually {
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
err = tx.Insert(ctx, am)
if err != nil {
return nil, err
}
newAuthzIDs = append(newAuthzIDs, am.ID)
newAuthzIDs := make([]int64, 0, len(req.NewAuthzs))
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
} else {
if len(req.NewAuthzs) != 0 {
inserter, err := db.NewMultiInserter("authz2", strings.Split(authzFields, ", "), "id")
if err != nil {
return nil, err
}
for _, authz := range req.NewAuthzs {
am, err := newAuthzReqToModel(authz)
if err != nil {
return nil, err
}
err = inserter.Add([]interface{}{
am.ID,
am.IdentifierType,
am.IdentifierValue,
am.RegistrationID,
statusToUint[core.StatusPending],
am.Expires,
am.Challenges,
nil,
nil,
am.Token,
nil,
nil,
})
if err != nil {
return nil, err
}
}
newAuthzIDs, err = inserter.Insert(ctx, tx)
if err != nil {
return nil, err
}
err = tx.Insert(ctx, am)
if err != nil {
return nil, err
}
newAuthzIDs = append(newAuthzIDs, am.ID)
}

// Second, insert the new order.
Expand Down Expand Up @@ -549,7 +512,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb
}

// Third, insert all of the orderToAuthz relations.
inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"}, "")
inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"})
if err != nil {
return nil, err
}
Expand All @@ -565,7 +528,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb
return nil, err
}
}
_, err = inserter.Insert(ctx, tx)
err = inserter.Insert(ctx, tx)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading