Skip to content

Commit

Permalink
feat(headers): use as supplied + support multiple headers (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
gorillamoe authored Oct 4, 2024
1 parent 39ce81e commit a408fcf
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 40 deletions.
7 changes: 6 additions & 1 deletion lua/kulala/cmd/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ M.run_parser = function(req, callback)
end
end
INT_PROCESSING.redirect_response_body_to_file(result.redirect_response_body_to_files)
PARSER.scripts.javascript.run("post_request", result.scripts.post_request)

local has_post_request_scripts = #result.scripts.post_request.inline > 0
or #result.scripts.pre_request.files > 0
if has_post_request_scripts then
PARSER.scripts.javascript.run("post_request", result.scripts.post_request)
end
Api.trigger("after_request")
end
Fs.delete_request_scripts_files()
Expand Down
2 changes: 1 addition & 1 deletion lua/kulala/globals/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local M = {}

local plugin_tmp_dir = FS.get_plugin_tmp_dir()

M.VERSION = "4.0.2"
M.VERSION = "4.0.3"
M.UI_ID = "kulala://ui"
M.SCRATCHPAD_ID = "kulala://scratchpad"
M.HEADERS_FILE = plugin_tmp_dir .. "/headers.txt"
Expand Down
41 changes: 31 additions & 10 deletions lua/kulala/internal_processing/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,39 @@ local function get_nested_value(t, key)
return value
end

local get_headers_as_table = function()
---Function to get the last headers as a table
---@description Reads the headers file and returns the headers as a table.
---In some cases the headers file might contain multiple header sections,
---e.g. if you have follow-redirections enabled.
---This function will return the headers of the last response.
---@return table
local get_last_headers_as_table = function()
local headers_file = FS.read_file(GLOBALS.HEADERS_FILE):gsub("\r\n", "\n")
local lines = vim.split(headers_file, "\n")
local headers_table = {}
-- INFO:
-- We only want the headers of the last response
-- so we reset the headers_table only each time the previous line was empty
-- and we also have new headers data
local previously_empty = false
for _, header in ipairs(lines) do
if header:find(":") ~= nil then
local kv = vim.split(header, ":")
local key = kv[1]
-- the value should be everything after the first colon
-- but we can't use slice and join because the value might contain colons
local value = header:sub(#key + 2)
headers_table[key] = vim.trim(value)
local empty_line = header == ""
if empty_line then
previously_empty = true
else
if previously_empty then
headers_table = {}
end
previously_empty = false
if header:find(":") ~= nil then
local kv = vim.split(header, ":")
local key = kv[1]
-- INFO:
-- the value should be everything after the first colon
-- but we can't use slice and join because the value might contain colons
local value = header:sub(#key + 2)
headers_table[key] = vim.trim(value)
end
end
end
return headers_table
Expand Down Expand Up @@ -78,7 +99,7 @@ local get_cookies_as_table = function()
end

local get_lower_headers_as_table = function()
local headers = get_headers_as_table()
local headers = get_last_headers_as_table()
local headers_table = {}
for key, value in pairs(headers) do
headers_table[key:lower()] = value
Expand All @@ -101,7 +122,7 @@ end
M.set_env_for_named_request = function(name, body)
local named_request = {
response = {
headers = get_headers_as_table(),
headers = get_last_headers_as_table(),
body = body,
cookies = get_cookies_as_table(),
},
Expand Down
46 changes: 22 additions & 24 deletions lua/kulala/parser/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ local REQUEST_VARIABLES = require("kulala.parser.request_variables")
local STRING_UTILS = require("kulala.utils.string")
local PARSER_UTILS = require("kulala.parser.utils")
local TS = require("kulala.parser.treesitter")
local PLUGIN_TMP_DIR = FS.get_plugin_tmp_dir()
local CURL_FORMAT_FILE = FS.get_plugin_path({ "parser", "curl-format.json" })
local Logger = require("kulala.logger")

Expand Down Expand Up @@ -342,8 +341,8 @@ M.get_document = function()
-- dynamic variables are defined as `{{$variable_name}}`
local key, value = line:match("^([^:]+):%s*(.*)$")
if key and value then
request.headers[key:lower()] = value
request.headers_raw[key:lower()] = value
request.headers[key] = value
request.headers_raw[key] = value
end
elseif is_request_line == true then
-- Request line (e.g., GET http://example.com HTTP/1.1)
Expand Down Expand Up @@ -455,13 +454,9 @@ end
---@field file string -- The file path to write the response body to
---@field overwrite boolean -- Whether to overwrite the file if it already exists

---@class ScriptsItems
---@field inline table -- Inline post-request handler scripts - each element is a line of the script
---@field files table -- File post-request handler scripts - each element is a file path
---
---@class Scripts
---@field pre_request ScriptsItems -- Pre-request handler scripts
---@field post_request ScriptsItems -- Post-request handler scripts
---@field pre_request ScriptData -- Pre-request handler scripts
---@field post_request ScriptData -- Post-request handler scripts

---@class Request
---@field metadata table[] -- Metadata of the request
Expand Down Expand Up @@ -575,12 +570,12 @@ M.parse = function(start_request_linenr)
res.url, res.headers, res.body =
replace_variables_in_url_headers_body(res, document_variables, env, has_pre_request_scripts)

-- Merge headers from the $shared environment if it exists
-- Merge headers from the $shared environment if it does not exist in the request
-- this ensures that you can always override the headers in the request
if DB.find_unique("http_client_env_shared") then
local default_headers = DB.find_unique("http_client_env_shared")["$default_headers"]
if default_headers then
for key, value in pairs(default_headers) do
key = key:lower()
if res.headers[key] == nil then
res.headers[key] = value
end
Expand Down Expand Up @@ -617,13 +612,15 @@ M.parse = function(start_request_linenr)
table.insert(res.cmd, res.method)

local is_graphql = PARSER_UTILS.contains_meta_tag(res, "graphql")
or PARSER_UTILS.contains_header(res.headers, "x-request-type", "GraphQL")
or PARSER_UTILS.contains_header(res.headers, "x-request-type", "graphql")
if CONFIG.get().treesitter then
-- treesitter parser handles graphql requests before this point
is_graphql = false
end

if res.headers["content-type"] ~= nil and res.body ~= nil then
local content_type_header_name, content_type_header_value = PARSER_UTILS.get_header(res.headers, "content-type")

if content_type_header_name and content_type_header_value and res.body ~= nil then
-- check if we are a graphql query
-- we need this here, because the user could have defined the content-type
-- as application/json, but the body is a graphql query
Expand All @@ -633,9 +630,9 @@ M.parse = function(start_request_linenr)
if gql_json then
table.insert(res.cmd, "--data")
table.insert(res.cmd, gql_json)
res.headers["content-type"] = "application/json"
res.headers[content_type_header_name] = "application/json"
end
elseif res.headers["content-type"]:find("^multipart/form%-data") then
elseif content_type_header_value:find("^multipart/form%-data") then
local tmp_file = FS.get_binary_temp_file(res.body)
if tmp_file ~= nil then
table.insert(res.cmd, "--data-binary")
Expand All @@ -654,31 +651,32 @@ M.parse = function(start_request_linenr)
if gql_json then
table.insert(res.cmd, "--data")
table.insert(res.cmd, gql_json)
res.headers["content-type"] = "application/json"
res.headers[content_type_header_name] = "application/json"
end
end
end

if res.headers["authorization"] then
local auth_header = res.headers["authorization"]
local authtype = auth_header:match("^(%w+)%s+.*")
local auth_header_name, auth_header_value = PARSER_UTILS.get_header(res.headers, "authorization")

if auth_header_name and auth_header_value then
local authtype = auth_header_value:match("^(%w+)%s+.*")
if authtype == nil then
authtype = auth_header:match("^(%w+)%s*$")
authtype = auth_header_value:match("^(%w+)%s*$")
end

if authtype ~= nil then
authtype = authtype:lower()

if authtype == "ntlm" or authtype == "negotiate" or authtype == "digest" or authtype == "basic" then
local match, authuser, authpw = auth_header:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$")
local match, authuser, authpw = auth_header_value:match("^(%w+)%s+([^%s:]+)%s*[:%s]%s*([^%s]+)%s*$")
if match ~= nil or (authtype == "ntlm" or authtype == "negotiate") then
table.insert(res.cmd, "--" .. authtype)
table.insert(res.cmd, "-u")
table.insert(res.cmd, (authuser or "") .. ":" .. (authpw or ""))
res.headers["authorization"] = nil
res.headers[auth_header_name] = nil
end
elseif authtype == "aws" then
local key, secret, optional = auth_header:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$")
local key, secret, optional = auth_header_value:match("^%w+%s([^%s]+)%s*([^%s]+)[%s$]+(.*)$")
local token = optional:match("token:([^%s]+)")
local region = optional:match("region:([^%s]+)")
local service = optional:match("service:([^%s]+)")
Expand All @@ -697,7 +695,7 @@ M.parse = function(start_request_linenr)
table.insert(res.cmd, "-H")
table.insert(res.cmd, "x-amz-security-token:" .. token)
end
res.headers["authorization"] = nil
res.headers[auth_header_name] = nil
end
end
end
Expand Down
104 changes: 100 additions & 4 deletions lua/kulala/parser/utils.lua
Original file line number Diff line number Diff line change
@@ -1,21 +1,117 @@
local M = {}

-- PERF: we do a lot of if else blocks with repeating loops
-- we could "optimize" this by using a single loop and if else blocks
-- that would make the code more readable and easier to maintain
-- but it would also make it slower

---Check if a request has a specific meta tag
---@param request table The request to check
---@param tag string The meta tag to check for
M.contains_meta_tag = function(request, tag)
tag = tag:lower()
for _, meta in ipairs(request.metadata) do
if meta.name == tag then
if meta.name:lower() == tag then
return true
end
end
return false
end

---Check if a header is present in the request
---@param headers table The headers to check
---@param header string The header name to check
---@param value string|nil The value to check for or nil if only the header name should be checked
---@return boolean
M.contains_header = function(headers, header, value)
for k, v in pairs(headers) do
if k == header and v == value then
return true
header = header:lower()
value = value and value:lower() or nil
vim.print("header: " .. header .. " value: " .. value)
if value == nil then
for k, _ in pairs(headers) do
if k:lower() == header then
return true
end
end
else
for k, v in pairs(headers) do
if k:lower() == header and v:lower() == value then
return true
end
end
end
return false
end

---Get the value of a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return string|nil
M.get_header_value = function(headers, header, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
for k, v in pairs(headers) do
if k == header then
return v
end
end
return nil
end

---Get the name of a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return string|nil
M.get_header_name = function(headers, header, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
for k, _ in pairs(headers) do
if k:lower() == header then
return k
end
end
return nil
end

---Get a header from the request
---@param headers table The headers to check
---@param header string The header name to check
---@param value string|nil The value to check for or nil if only the header name should be checked
---@param dont_ignore_case boolean|nil If true, the header name will be case sensitive
---@return (string|nil), (string|nil) The header name and value or nil if not found
M.get_header = function(headers, header, value, dont_ignore_case)
header = dont_ignore_case and header or header:lower()
value = value and (dont_ignore_case and value or value:lower()) or nil
if dont_ignore_case then
if value == nil then
for k, _ in pairs(headers) do
if k == header then
return k, headers[k]
end
end
else
for k, v in pairs(headers) do
if k == header and v == value then
return k, v
end
end
end
else
if value == nil then
for k, _ in pairs(headers) do
if k:lower() == header then
return k, headers[k]
end
end
else
for k, v in pairs(headers) do
if k:lower() == header and v:lower() == value then
return k, v
end
end
end
end
return nil, nil
end

return M

0 comments on commit a408fcf

Please sign in to comment.