diff --git a/lua/kulala/cmd/init.lua b/lua/kulala/cmd/init.lua index cc8b218..37d08e5 100644 --- a/lua/kulala/cmd/init.lua +++ b/lua/kulala/cmd/init.lua @@ -123,7 +123,12 @@ M.run_parser = function(req, callback) end end INT_PROCESSING.redirect_response_body_to_file(result.redirect_response_body_to_files) - PARSER.scripts.javascript.run("post_request", result.scripts.post_request) + + local has_post_request_scripts = #result.scripts.post_request.inline > 0 + or #result.scripts.pre_request.files > 0 + if has_post_request_scripts then + PARSER.scripts.javascript.run("post_request", result.scripts.post_request) + end Api.trigger("after_request") end Fs.delete_request_scripts_files() diff --git a/lua/kulala/globals/init.lua b/lua/kulala/globals/init.lua index 377e9b0..528404d 100644 --- a/lua/kulala/globals/init.lua +++ b/lua/kulala/globals/init.lua @@ -4,7 +4,7 @@ local M = {} local plugin_tmp_dir = FS.get_plugin_tmp_dir() -M.VERSION = "4.0.2" +M.VERSION = "4.0.3" M.UI_ID = "kulala://ui" M.SCRATCHPAD_ID = "kulala://scratchpad" M.HEADERS_FILE = plugin_tmp_dir .. "/headers.txt" diff --git a/lua/kulala/internal_processing/init.lua b/lua/kulala/internal_processing/init.lua index d62bbe3..800b6de 100644 --- a/lua/kulala/internal_processing/init.lua +++ b/lua/kulala/internal_processing/init.lua @@ -18,18 +18,39 @@ local function get_nested_value(t, key) return value end -local get_headers_as_table = function() +---Function to get the last headers as a table +---@description Reads the headers file and returns the headers as a table. +---In some cases the headers file might contain multiple header sections, +---e.g. if you have follow-redirections enabled. +---This function will return the headers of the last response. +---@return table +local get_last_headers_as_table = function() local headers_file = FS.read_file(GLOBALS.HEADERS_FILE):gsub("\r\n", "\n") local lines = vim.split(headers_file, "\n") local headers_table = {} + -- INFO: + -- We only want the headers of the last response + -- so we reset the headers_table only each time the previous line was empty + -- and we also have new headers data + local previously_empty = false for _, header in ipairs(lines) do - if header:find(":") ~= nil then - local kv = vim.split(header, ":") - local key = kv[1] - -- the value should be everything after the first colon - -- but we can't use slice and join because the value might contain colons - local value = header:sub(#key + 2) - headers_table[key] = vim.trim(value) + local empty_line = header == "" + if empty_line then + previously_empty = true + else + if previously_empty then + headers_table = {} + end + previously_empty = false + if header:find(":") ~= nil then + local kv = vim.split(header, ":") + local key = kv[1] + -- INFO: + -- the value should be everything after the first colon + -- but we can't use slice and join because the value might contain colons + local value = header:sub(#key + 2) + headers_table[key] = vim.trim(value) + end end end return headers_table @@ -78,7 +99,7 @@ local get_cookies_as_table = function() end local get_lower_headers_as_table = function() - local headers = get_headers_as_table() + local headers = get_last_headers_as_table() local headers_table = {} for key, value in pairs(headers) do headers_table[key:lower()] = value @@ -101,7 +122,7 @@ end M.set_env_for_named_request = function(name, body) local named_request = { response = { - headers = get_headers_as_table(), + headers = get_last_headers_as_table(), body = body, cookies = get_cookies_as_table(), }, diff --git a/lua/kulala/parser/init.lua b/lua/kulala/parser/init.lua index 08dfb37..c3b810f 100644 --- a/lua/kulala/parser/init.lua +++ b/lua/kulala/parser/init.lua @@ -11,7 +11,6 @@ local REQUEST_VARIABLES = require("kulala.parser.request_variables") local STRING_UTILS = require("kulala.utils.string") local PARSER_UTILS = require("kulala.parser.utils") local TS = require("kulala.parser.treesitter") -local PLUGIN_TMP_DIR = FS.get_plugin_tmp_dir() local CURL_FORMAT_FILE = FS.get_plugin_path({ "parser", "curl-format.json" }) local Logger = require("kulala.logger") @@ -342,8 +341,8 @@ M.get_document = function() -- dynamic variables are defined as `{{$variable_name}}` local key, value = line:match("^([^:]+):%s*(.*)$") if key and value then - request.headers[key:lower()] = value - request.headers_raw[key:lower()] = value + request.headers[key] = value + request.headers_raw[key] = value end elseif is_request_line == true then -- Request line (e.g., GET http://example.com HTTP/1.1) @@ -455,13 +454,9 @@ end ---@field file string -- The file path to write the response body to ---@field overwrite boolean -- Whether to overwrite the file if it already exists ----@class ScriptsItems ----@field inline table -- Inline post-request handler scripts - each element is a line of the script ----@field files table -- File post-request handler scripts - each element is a file path ---- ---@class Scripts ----@field pre_request ScriptsItems -- Pre-request handler scripts ----@field post_request ScriptsItems -- Post-request handler scripts +---@field pre_request ScriptData -- Pre-request handler scripts +---@field post_request ScriptData -- Post-request handler scripts ---@class Request ---@field metadata table[] -- Metadata of the request @@ -575,12 +570,12 @@ M.parse = function(start_request_linenr) res.url, res.headers, res.body = replace_variables_in_url_headers_body(res, document_variables, env, has_pre_request_scripts) - -- Merge headers from the $shared environment if it exists + -- Merge headers from the $shared environment if it does not exist in the request + -- this ensures that you can always override the headers in the request if DB.find_unique("http_client_env_shared") then local default_headers = DB.find_unique("http_client_env_shared")["$default_headers"] if default_headers then for key, value in pairs(default_headers) do - key = key:lower() if res.headers[key] == nil then res.headers[key] = value end @@ -617,13 +612,15 @@ M.parse = function(start_request_linenr) table.insert(res.cmd, res.method) local is_graphql = PARSER_UTILS.contains_meta_tag(res, "graphql") - or PARSER_UTILS.contains_header(res.headers, "x-request-type", "GraphQL") + or PARSER_UTILS.contains_header(res.headers, "x-request-type", "graphql") if CONFIG.get().treesitter then -- treesitter parser handles graphql requests before this point is_graphql = false end - if res.headers["content-type"] ~= nil and res.body ~= nil then + local content_type_header_name, content_type_header_value = PARSER_UTILS.get_header(res.headers, "content-type") + + if content_type_header_name and content_type_header_value and res.body ~= nil then -- check if we are a graphql query -- we need this here, because the user could have defined the content-type -- as application/json, but the body is a graphql query @@ -633,9 +630,9 @@ M.parse = function(start_request_linenr) if gql_json then table.insert(res.cmd, "--data") table.insert(res.cmd, gql_json) - res.headers["content-type"] = "application/json" + res.headers[content_type_header_name] = "application/json" end - elseif res.headers["content-type"]:find("^multipart/form%-data") then + elseif content_type_header_value:find("^multipart/form%-data") then local tmp_file = FS.get_binary_temp_file(res.body) if tmp_file ~= nil then table.insert(res.cmd, "--data-binary") @@ -654,31 +651,32 @@ M.parse = function(start_request_linenr) if gql_json then table.insert(res.cmd, "--data") table.insert(res.cmd, gql_json) - res.headers["content-type"] = "application/json" + res.headers[content_type_header_name] = "application/json" end end end - if res.headers["authorization"] then - local auth_header = res.headers["authorization"] - local authtype = auth_header:match("^(%w+)%s+.*") + local auth_header_name, auth_header_value = PARSER_UTILS.get_header(res.headers, "authorization") + + if auth_header_name and auth_header_value then + local authtype = auth_header_value:match("^(%w+)%s+.*") if authtype == nil then - authtype = auth_header:match("^(%w+)%s*$") + authtype = auth_header_value:match("^(%w+)%s*$") end if authtype ~= nil then authtype = authtype:lower() if authtype == "ntlm" or authtype == "negotiate" or authtype == "digest" or authtype == "basic" then - local match, authuser, authpw = auth_header:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$") + local match, authuser, authpw = auth_header_value:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$") if match ~= nil or (authtype == "ntlm" or authtype == "negotiate") then table.insert(res.cmd, "--" .. authtype) table.insert(res.cmd, "-u") table.insert(res.cmd, (authuser or "") .. ":" .. (authpw or "")) - res.headers["authorization"] = nil + res.headers[auth_header_name] = nil end elseif authtype == "aws" then - local key, secret, optional = auth_header:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$") + local key, secret, optional = auth_header_value:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$") local token = optional:match("token:([^%s]+)") local region = optional:match("region:([^%s]+)") local service = optional:match("service:([^%s]+)") @@ -697,7 +695,7 @@ M.parse = function(start_request_linenr) table.insert(res.cmd, "-H") table.insert(res.cmd, "x-amz-security-token:" .. token) end - res.headers["authorization"] = nil + res.headers[auth_header_name] = nil end end end diff --git a/lua/kulala/parser/utils.lua b/lua/kulala/parser/utils.lua index 374ab31..b974f70 100644 --- a/lua/kulala/parser/utils.lua +++ b/lua/kulala/parser/utils.lua @@ -1,21 +1,117 @@ local M = {} +-- PERF: we do a lot of if else blocks with repeating loops +-- we could "optimize" this by using a single loop and if else blocks +-- that would make the code more readable and easier to maintain +-- but it would also make it slower + +---Check if a request has a specific meta tag +---@param request table The request to check +---@param tag string The meta tag to check for M.contains_meta_tag = function(request, tag) + tag = tag:lower() for _, meta in ipairs(request.metadata) do - if meta.name == tag then + if meta.name:lower() == tag then return true end end return false end +---Check if a header is present in the request +---@param headers table The headers to check +---@param header string The header name to check +---@param value string|nil The value to check for or nil if only the header name should be checked +---@return boolean M.contains_header = function(headers, header, value) - for k, v in pairs(headers) do - if k == header and v == value then - return true + header = header:lower() + value = value and value:lower() or nil + vim.print("header: " .. header .. " value: " .. value) + if value == nil then + for k, _ in pairs(headers) do + if k:lower() == header then + return true + end + end + else + for k, v in pairs(headers) do + if k:lower() == header and v:lower() == value then + return true + end end end return false end +---Get the value of a header from the request +---@param headers table The headers to check +---@param header string The header name to check +---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive +---@return string|nil +M.get_header_value = function(headers, header, dont_ignore_case) + header = dont_ignore_case and header or header:lower() + for k, v in pairs(headers) do + if k == header then + return v + end + end + return nil +end + +---Get the name of a header from the request +---@param headers table The headers to check +---@param header string The header name to check +---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive +---@return string|nil +M.get_header_name = function(headers, header, dont_ignore_case) + header = dont_ignore_case and header or header:lower() + for k, _ in pairs(headers) do + if k:lower() == header then + return k + end + end + return nil +end + +---Get a header from the request +---@param headers table The headers to check +---@param header string The header name to check +---@param value string|nil The value to check for or nil if only the header name should be checked +---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive +---@return (string|nil), (string|nil) The header name and value or nil if not found +M.get_header = function(headers, header, value, dont_ignore_case) + header = dont_ignore_case and header or header:lower() + value = value and (dont_ignore_case and value or value:lower()) or nil + if dont_ignore_case then + if value == nil then + for k, _ in pairs(headers) do + if k == header then + return k, headers[k] + end + end + else + for k, v in pairs(headers) do + if k == header and v == value then + return k, v + end + end + end + else + if value == nil then + for k, _ in pairs(headers) do + if k:lower() == header then + return k, headers[k] + end + end + else + for k, v in pairs(headers) do + if k:lower() == header and v:lower() == value then + return k, v + end + end + end + end + return nil, nil +end + return M