diff --git a/main.go b/main.go index d21be5c..31de651 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ import ( "log" "net/http" "os" + "strconv" "github.com/film42/pgreba/config" "github.com/gorilla/handlers" @@ -18,41 +19,6 @@ type HealthCheckWebService struct { healthChecker *HealthChecker } -// func (hc *HealthCheckWebService) getSlotHealthCheck(w http.ResponseWriter, r *http.Request) { -// // Get request info -// w.Header().Set("Content-Type", "application/json") -// params := mux.Vars(r) -// slotName := params["slot_name"] - -// // Perform the health check. -// err := hc.healthChecker.CheckReplicationSlot(slotName) - -// // If the slot is OK, return status: ok. -// if err == nil { -// json.NewEncoder(w).Encode(map[string]string{ -// "status": "ok", -// "slot": slotName, -// }) -// return -// } - -// // If there was an error, set the appropriate status code. -// switch err { -// case ErrReplicationSlotNotFound: -// w.WriteHeader(http.StatusNotFound) -// case ErrReplicationSlotLagTooHigh: -// w.WriteHeader(http.StatusServiceUnavailable) -// default: -// w.WriteHeader(http.StatusInternalServerError) -// } - -// // Return error to the client. -// json.NewEncoder(w).Encode(map[string]string{ -// "error": err.Error(), -// "slot": slotName, -// }) -// } - func (hc *HealthCheckWebService) apiGetIsPrimary(w http.ResponseWriter, r *http.Request) { nodeInfo, err := hc.healthChecker.dataSource.GetNodeInfo() if err != nil { @@ -78,9 +44,28 @@ func (hc *HealthCheckWebService) apiGetIsReplica(w http.ResponseWriter, r *http. w.WriteHeader(http.StatusServiceUnavailable) } + // if byte lag exceeds max_allowable_byte_lag then return 500 + if maxAllowableByteLagExceeded(r, nodeInfo) { + w.WriteHeader(http.StatusServiceUnavailable) + } + json.NewEncoder(w).Encode(nodeInfo) } +func maxAllowableByteLagExceeded(r *http.Request, nodeInfo *NodeInfo) bool { + maxAllowableByteLagString := r.URL.Query().Get("max_allowable_byte_lag") + if len(maxAllowableByteLagString) == 0 { + return true + } + + maxAllowableByteLag, err := strconv.ParseInt(maxAllowableByteLagString, 10, 64) + if err != nil { + panic(err) + } + + return nodeInfo.ByteLag > maxAllowableByteLag +} + func main() { versionPtr := flag.Bool("version", false, "Print the teecp version and exit.") flag.Parse() @@ -118,7 +103,6 @@ func main() { return handlers.LoggingHandler(log.Writer(), next) }) - // For primary nodes router.HandleFunc("/", hcs.apiGetIsPrimary).Methods("GET") router.HandleFunc("/primary", hcs.apiGetIsPrimary).Methods("GET")