diff --git a/controller/db/select.go b/controller/db/select.go index 06370a4..d378ef6 100644 --- a/controller/db/select.go +++ b/controller/db/select.go @@ -3,9 +3,8 @@ package controller import ( "encoding/json" "io/ioutil" - + "strings" "net/http" - "github.com/DropKit/DropKit-Adapter/constants" "github.com/DropKit/DropKit-Adapter/logger" "github.com/DropKit/DropKit-Adapter/package/crypto/account" @@ -18,6 +17,25 @@ import ( "github.com/spf13/viper" ) +func chekSelectAll(sqlCommand string, columnsCanSelect []string) string { + var idx int = -1; + for i, v := range sqlCommand { + if(v == '*'){ + idx = i + break + } + } + if(idx == -1){ + return sqlCommand + } + + sqlSlice := []byte(sqlCommand[0:idx]) + columnsStr := strings.Join(columnsCanSelect, ",") + sqlSlice = append(sqlSlice, []byte(columnsStr)...) + sqlSlice = append(sqlSlice, []byte(sqlCommand[idx+1:])...) + return string(sqlSlice) +} + func HandleDBSelection(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -39,6 +57,7 @@ func HandleDBSelection(w http.ResponseWriter, r *http.Request) { } sqlCommand := newStatement.Statement + callerPriavteKey := newStatement.PrivateKey callerAddress, err := account.PrivateKeyToPublicKey(callerPriavteKey) if err != nil { @@ -51,12 +70,7 @@ func HandleDBSelection(w http.ResponseWriter, r *http.Request) { services.NormalResponse(w, response.SQLResponseBadSQLStatement()) return } - - columnsNames, err := columns.GetSelectColumns(sqlCommand) - if err != nil { - services.NormalResponse(w, response.SQLResponseBadSQLStatement()) - return - } + result, err := services.HasTableUserRole(callerPriavteKey, callerAddress, tableName) if err != nil { @@ -66,11 +80,22 @@ func HandleDBSelection(w http.ResponseWriter, r *http.Request) { switch result { case true: + columnsCanSelect, err := services.GetColumnsRole(callerPriavteKey, callerAddress, tableName) + if err != nil { services.NormalResponse(w, response.ResponseInternalError()) return } + + sqlCommand = chekSelectAll(sqlCommand, columnsCanSelect) + + columnsNames, err := columns.GetSelectColumns(sqlCommand) + if err != nil { + services.NormalResponse(w, response.SQLResponseBadSQLStatement()) + return + } + columnsAuth := utils.CompareColumns(columnsCanSelect, columnsNames) switch columnsAuth {