From a67369c8a62ded1f6988ae58e4d09e5c12b1eee2 Mon Sep 17 00:00:00 2001 From: Christian Roessner Date: Mon, 9 Dec 2024 12:10:47 +0100 Subject: [PATCH] Fix: Refactor session handling and error handling. Revised session validation logic, removed redundant code, and replaced hardcoded protocol checks with a lookup table for better maintainability. Introduced a new error message "ErrFilterFailed" to provide more specific feedback on filter execution failures. These updates aim to improve code readability, error traceability, and system robustness. Signed-off-by: Christian Roessner --- server/errors/errors.go | 1 + server/lua-plugins.d/filters/monitoring.lua | 152 ++++++++---------- .../hooks/dovecot-session-cleaner.lua | 5 +- server/lualib/filter/filter.go | 20 ++- 4 files changed, 84 insertions(+), 94 deletions(-) diff --git a/server/errors/errors.go b/server/errors/errors.go index 02b94dd3..93c2c577 100644 --- a/server/errors/errors.go +++ b/server/errors/errors.go @@ -214,6 +214,7 @@ var ( ErrNoFiltersDefined = errors.New("no filters defined") ErrFilterLuaNameMissing = errors.New("filter 'name' sttribute missing") ErrFilterLuaScriptPathEmpty = errors.New("filter 'script_path' attribute missing") + ErrFilterFailed = errors.New("filter failed") ) // misc. diff --git a/server/lua-plugins.d/filters/monitoring.lua b/server/lua-plugins.d/filters/monitoring.lua index dcbc43cd..ea15b580 100644 --- a/server/lua-plugins.d/filters/monitoring.lua +++ b/server/lua-plugins.d/filters/monitoring.lua @@ -16,37 +16,24 @@ dynamic_loader("nauthilus_backend") local nauthilus_backend = require("nauthilus_backend") -local N = "monitoring" - -local wanted_protocols = { - "imap", "imapa", "pop3", "pop3s", "lmtp", "lmtps", - "sieve", -- Not sure about this +local N = "director" + +local WANTED_PROTOCOLS = { + imap = true, + imapa = true, + pop3 = true, + pop3s = true, + lmtp = true, + lmtps = true, + sieve = true, } function nauthilus_call_filter(request) - local skip_and_accept_filter = false - - -- Dovecot userdb request - if request.authenticated and request.no_auth then - skip_and_accept_filter = true - end - - -- Dovecot passdb request - if request.authenticated and not request.no_auth then - skip_and_accept_filter = true - - for _, proto in ipairs(wanted_protocols) do - if proto == request.protocol then - skip_and_accept_filter = false - - break - end - end + if not request.authenticated then + return nauthilus_builtin.FILTER_REJECT, nauthilus_builtin.FILTER_RESULT_OK end - if skip_and_accept_filter then - nauthilus_backend.remove_from_backend_result({ "Proxy-Host" }) - + if not WANTED_PROTOCOLS[request.protocol] then return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK end @@ -66,6 +53,7 @@ function nauthilus_call_filter(request) dynamic_loader("nauthilus_redis") local nauthilus_redis = require("nauthilus_redis") + local redis_key = "ntc:DS:" .. request.account local custom_pool = "default" local custom_pool_name = os.getenv("CUSTOM_REDIS_POOL_NAME") @@ -76,7 +64,7 @@ function nauthilus_call_filter(request) nauthilus_util.if_error_raise(err_redis_client) end - local function set_initial_expiry(redis_key) + local function set_initial_expiry() local length, err_redis_hlen = nauthilus_redis.redis_hlen(custom_pool, redis_key) if err_redis_hlen then if err_redis_hlen ~= "redis: nil" then @@ -90,16 +78,13 @@ function nauthilus_call_filter(request) end end - dynamic_loader("nauthilus_gluacrypto") - local crypto = require("crypto") - - local function add_session(session, server) - if session == nil then - return - end + local function invalidate_stale_sessions() + local _, err_redis_hdel = nauthilus_redis.redis_del(custom_pool, redis_key) - local redis_key = "ntc:DS:" .. crypto.md5(request.account) + nauthilus_util.if_error_raise(err_redis_hdel) + end + local function add_session(session, server) local _, err_redis_hset = nauthilus_redis.redis_hset(custom_pool, redis_key, session, server) if err_redis_hset then nauthilus_builtin.custom_log_add(N .. "_redis_hset_error", err_redis_hset) @@ -107,13 +92,11 @@ function nauthilus_call_filter(request) return end - set_initial_expiry(redis_key) + set_initial_expiry() nauthilus_builtin.custom_log_add(N .. "_dovecot_session", session) end local function get_server_from_sessions(session) - local redis_key = "ntc:DS:" .. crypto.md5(request.account) - local server_from_session, err_redis_hget = nauthilus_redis.redis_hget(custom_pool, redis_key, session) if err_redis_hget then if err_redis_hget ~= "redis: nil" then @@ -143,70 +126,75 @@ function nauthilus_call_filter(request) return nil end - -- Only look for backend servers, if a user was authenticated (passdb requests) - if request.authenticated and not request.no_auth then - local num_of_bs = 0 + local function preprocess_backend_servers(backend_servers) + local valid_servers = {} - local backend_servers = nauthilus_backend.get_backend_servers() - if nauthilus_util.is_table(backend_servers) then - num_of_bs = nauthilus_util.table_length(backend_servers) - - local server_host = "" - local new_server_host = "" - - local session = get_dovecot_session() - if session then - local maybe_server = get_server_from_sessions(session) - if maybe_server then - server_host = maybe_server - end + for _, server in ipairs(backend_servers) do + if server.protocol == request.protocol then + table.insert(valid_servers, server) end + end - if num_of_bs > 0 then - local attributes = {} - - local b = nauthilus_backend_result.new() + return valid_servers + end - for _, server in ipairs(backend_servers) do - new_server_host = server.host + local server_host + local session = get_dovecot_session() - if server_host == new_server_host then - attributes["Proxy-Host"] = server_host + if session then + local valid_servers = preprocess_backend_servers(nauthilus_backend.get_backend_servers()) + local num_of_bs = nauthilus_util.table_length(valid_servers) - add_session(session, server_host) - nauthilus_builtin.custom_log_add(N .. "_backend_server_current", server_host) + if num_of_bs > 0 then + local maybe_server = get_server_from_sessions(session) - b:attributes(attributes) - nauthilus_backend.apply_backend_result(b) + if maybe_server then + for _, server in ipairs(valid_servers) do + if server.host == maybe_server then + server_host = maybe_server break end end - if server_host ~= new_server_host then - -- Put your own logic here to select a proper server for the user. In this demo, the last server - -- available is always used. - attributes["Proxy-Host"] = new_server_host + if not server_host then + invalidate_stale_sessions() - add_session(session, new_server_host) - nauthilus_builtin.custom_log_add(N .. "_backend_server_new", new_server_host) - - b:attributes(attributes) - nauthilus_backend.apply_backend_result(b) + server_host = valid_servers[math.random(1, num_of_bs)].host end + else + server_host = valid_servers[math.random(1, num_of_bs)].host end end - if num_of_bs == 0 then - nauthilus_builtin.custom_log_add(N .. "_backend_server", "failed") - nauthilus_builtin.status_message_set("No backend servers are available") + if server_host then + local backend_result = nauthilus_backend_result.new() + local attributes = {} + + add_session(session, server_host) - return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_FAIL + local expected_server = get_server_from_sessions(session) + + -- Another client might have been faster at the same point in time... + if expected_server and server_host ~= expected_server then + server_host = expected_server + end + + attributes["Proxy-Host"] = server_host + + nauthilus_builtin.custom_log_add(N .. "_backend_server", server_host) + + backend_result:attributes(attributes) + nauthilus_backend.apply_backend_result(backend_result) end + end - return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK + if server_host == nil then + nauthilus_builtin.custom_log_add(N .. "_backend_server", "failed") + nauthilus_builtin.status_message_set("No backend servers are available") + + return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_FAIL end - -- Anything else must be a rejected request - return nauthilus_builtin.FILTER_REJECT, nauthilus_builtin.FILTER_RESULT_OK + return nauthilus_builtin.FILTER_ACCEPT, nauthilus_builtin.FILTER_RESULT_OK end diff --git a/server/lua-plugins.d/hooks/dovecot-session-cleaner.lua b/server/lua-plugins.d/hooks/dovecot-session-cleaner.lua index 70932289..79836af3 100644 --- a/server/lua-plugins.d/hooks/dovecot-session-cleaner.lua +++ b/server/lua-plugins.d/hooks/dovecot-session-cleaner.lua @@ -21,9 +21,6 @@ local nauthilus_redis = require("nauthilus_redis") dynamic_loader("nauthilus_http_request") local nauthilus_http_request = require("nauthilus_http_request") -dynamic_loader("nauthilus_gluacrypto") -local crypto = require("crypto") - dynamic_loader("nauthilus_gll_json") local json = require("json") @@ -108,7 +105,7 @@ function nauthilus_run_hook(logging, session) if result.category == "service:imap" or result.category == "service:pop3" or result.category == "service:lmtp" or result.category == "service:sieve" then if result.dovecot_session ~= "unknown" then - local redis_key = "ntc:DS:" .. crypto.md5(result.user) + local redis_key = "ntc:DS:" .. result.user if is_cmd_noop then result.cmd = "NOOP" diff --git a/server/lualib/filter/filter.go b/server/lualib/filter/filter.go index af239e8e..5e0422e1 100644 --- a/server/lualib/filter/filter.go +++ b/server/lualib/filter/filter.go @@ -18,6 +18,7 @@ package filter import ( "context" stderrors "errors" + "fmt" "net/http" "sync" "time" @@ -445,6 +446,8 @@ func setRequest(r *Request, L *lua.LState) *lua.LTable { // It also calls the Lua function with the given parameters and logs the result. // The function will return a boolean indicating whether the Lua function was called successfully, and an error if any occurred. func executeScriptWithinContext(request *lua.LTable, script *LuaFilter, r *Request, ctx *gin.Context, L *lua.LState) (bool, error) { + var err error + stopTimer := stats.PrometheusTimer(definitions.PromFilter, script.Name) if stopTimer != nil { @@ -486,11 +489,15 @@ func executeScriptWithinContext(request *lua.LTable, script *LuaFilter, r *Reque logResult(r, script, action, result) + if result != 0 { + err = fmt.Errorf("%v: %s", errors.ErrFilterFailed, script.Name) + } + if action { - return true, nil + return true, err } - return false, nil + return false, err } // logError is a function that logs error information when a LuaFilter script fails during a Request session. @@ -516,12 +523,9 @@ func logResult(r *Request, script *LuaFilter, action bool, ret int) { "result", resultMap[ret], } - if ret != 0 { - - if r.Logs != nil { - for index := range *r.Logs { - logs = append(logs, (*r.Logs)[index]) - } + if r.Logs != nil { + for index := range *r.Logs { + logs = append(logs, (*r.Logs)[index]) } }