Skip to content

Commit

Permalink
Implement MD5 password auth + handle credential errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarred-Sumner committed Dec 29, 2024
1 parent c18b983 commit c13aa5f
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 87 deletions.
192 changes: 116 additions & 76 deletions src/sql/postgres.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1027,100 +1027,107 @@ pub const PostgresSQLConnection = struct {

pub const AuthenticationState = union(enum) {
pending: void,
SASL: SASL,
none: void,
ok: void,
SASL: SASL,
md5: void,

pub fn zero(this: *AuthenticationState) void {
const bytes = std.mem.asBytes(this);
@memset(bytes, 0);
switch (this.*) {
.SASL => |*sasl| {
sasl.deinit();
},
else => {},
}
this.* = .{ .none = {} };
}
};

pub const SASL = struct {
const nonce_byte_len = 18;
const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len);
pub const SASL = struct {
const nonce_byte_len = 18;
const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len);

const server_signature_byte_len = 32;
const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len);
const server_signature_byte_len = 32;
const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len);

const salted_password_byte_len = 32;
const salted_password_byte_len = 32;

nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len,
nonce_len: u8 = 0,
nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len,
nonce_len: u8 = 0,

server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len,
server_signature_len: u8 = 0,
server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len,
server_signature_len: u8 = 0,

salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len,
salted_password_created: bool = false,
salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len,
salted_password_created: bool = false,

status: SASLStatus = .init,
status: SASLStatus = .init,

pub const SASLStatus = enum {
init,
@"continue",
};
pub const SASLStatus = enum {
init,
@"continue",
};

fn hmac(password: []const u8, data: []const u8) ?[32]u8 {
var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8);
fn hmac(password: []const u8, data: []const u8) ?[32]u8 {
var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8);

// TODO: I don't think this is failable.
const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null;
// TODO: I don't think this is failable.
const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null;

assert(result.len == 32);
return buf[0..32].*;
}
assert(result.len == 32);
return buf[0..32].*;
}

pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void {
this.salted_password_created = true;
if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) {
return error.PBKDFD2;
}
pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void {
this.salted_password_created = true;
if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) {
return error.PBKDFD2;
}
}

pub fn saltedPassword(this: *const SASL) []const u8 {
assert(this.salted_password_created);
return this.salted_password_bytes[0..salted_password_byte_len];
}
pub fn saltedPassword(this: *const SASL) []const u8 {
assert(this.salted_password_created);
return this.salted_password_bytes[0..salted_password_byte_len];
}

pub fn serverSignature(this: *const SASL) []const u8 {
assert(this.server_signature_len > 0);
return this.server_signature_base64_bytes[0..this.server_signature_len];
}
pub fn serverSignature(this: *const SASL) []const u8 {
assert(this.server_signature_len > 0);
return this.server_signature_base64_bytes[0..this.server_signature_len];
}

pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void {
assert(this.server_signature_len == 0);
pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void {
assert(this.server_signature_len == 0);

const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey;
const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature;
this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes));
}
const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey;
const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature;
this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes));
}

pub fn clientKey(this: *const SASL) [32]u8 {
return hmac(this.saltedPassword(), "Client Key").?;
}
pub fn clientKey(this: *const SASL) [32]u8 {
return hmac(this.saltedPassword(), "Client Key").?;
}

pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 {
var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest);
bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine());
return hmac(&sha_digest, auth_string).?;
}
pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 {
var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest);
bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine());
return hmac(&sha_digest, auth_string).?;
}

pub fn nonce(this: *SASL) []const u8 {
if (this.nonce_len == 0) {
var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len;
bun.rand(&bytes);
this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes));
}
return this.nonce_base64_bytes[0..this.nonce_len];
pub fn nonce(this: *SASL) []const u8 {
if (this.nonce_len == 0) {
var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len;
bun.rand(&bytes);
this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes));
}
return this.nonce_base64_bytes[0..this.nonce_len];
}

pub fn deinit(this: *SASL) void {
this.nonce_len = 0;
this.salted_password_created = false;
this.server_signature_len = 0;
this.status = .init;
}
};
pub fn deinit(this: *SASL) void {
this.nonce_len = 0;
this.salted_password_created = false;
this.server_signature_len = 0;
this.status = .init;
}
};

pub const Status = enum {
Expand Down Expand Up @@ -1319,6 +1326,7 @@ pub const PostgresSQLConnection = struct {
if (this.status == .failed) return;

this.status = .failed;

this.ref();
defer this.deref();
if (!this.socket.isClosed()) this.socket.close();
Expand Down Expand Up @@ -1540,6 +1548,7 @@ pub const PostgresSQLConnection = struct {
this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer");
} else {
bun.handleErrorReturnTrace(err, @errorReturnTrace());

this.fail("Failed to read data", err);
}
};
Expand Down Expand Up @@ -2614,6 +2623,43 @@ pub const PostgresSQLConnection = struct {
this.flushData();
},

.MD5Password => |md5| {
debug("MD5Password", .{});
// Format is: md5 + md5(md5(password + username) + salt)
var first_hash_buf: bun.sha.MD5.Digest = undefined;
var first_hash_str: [32]u8 = undefined;
var final_hash_buf: bun.sha.MD5.Digest = undefined;
var final_hash_str: [32]u8 = undefined;
var final_password_buf: [36]u8 = undefined;

// First hash: md5(password + username)
var first_hasher = bun.sha.MD5.init();
first_hasher.update(this.password);
first_hasher.update(this.user);
first_hasher.final(&first_hash_buf);
const first_hash_str_output = std.fmt.bufPrint(&first_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&first_hash_buf)}) catch unreachable;

// Second hash: md5(first_hash + salt)
var final_hasher = bun.sha.MD5.init();
final_hasher.update(first_hash_str_output);
final_hasher.update(&md5.salt);
final_hasher.final(&final_hash_buf);
const final_hash_str_output = std.fmt.bufPrint(&final_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&final_hash_buf)}) catch unreachable;

// Format final password as "md5" + final_hash
const final_password = std.fmt.bufPrintZ(&final_password_buf, "md5{s}", .{final_hash_str_output}) catch unreachable;

var response = protocol.PasswordMessage{
.password = .{
.temporary = final_password,
},
};

this.authentication_state = .{ .md5 = {} };
try response.writeInternal(PostgresSQLConnection.Writer, this.writer());
this.flushData();
},

else => {
debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))});
this.fail("TODO: support authentication method: {s}", error.UNSUPPORTED_AUTHENTICATION_METHOD);
Expand All @@ -2634,18 +2680,12 @@ pub const PostgresSQLConnection = struct {
var err: protocol.ErrorResponse = undefined;
try err.decodeInternal(Context, reader);

if (this.status == .connecting) {
this.status = .failed;
if (this.status == .connecting or this.status == .sent_startup_message) {
defer {
err.deinit();
this.poll_ref.unref(this.globalObject.bunVM());
this.updateHasPendingActivity();
}

const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return;
defer on_connect.ensureStillAlive();
const js_value = this.js_value;
this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ err.toJS(this.globalObject), js_value });
this.failWithJSValue(err.toJS(this.globalObject));

// it shouldn't enqueue any requests while connecting
bun.assert(this.requests.count == 0);
Expand Down
3 changes: 0 additions & 3 deletions src/sql/postgres/postgres_protocol.zig
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,6 @@ pub const Authentication = union(enum) {
},
5 => {
if (message_length != 12) return error.InvalidMessageLength;
if (!try reader.expectInt(u32, 5)) {
return error.InvalidMessage;
}
var salt_data = try reader.bytes(4);
defer salt_data.deinit();
this.* = .{
Expand Down
27 changes: 19 additions & 8 deletions test/js/sql/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ if (!isCI) {
// local all postgres trust
// local all bun_sql_test_scram scram-sha-256
// local all bun_sql_test trust
//
// local all bun_sql_test_md5 md5

// # IPv4 local connections:
// host all ${USERNAME} 127.0.0.1/32 trust
// host all postgres 127.0.0.1/32 trust
// host all bun_sql_test_scram 127.0.0.1/32 scram-sha-256
// host all bun_sql_test 127.0.0.1/32 trust
// host all bun_sql_test_md5 127.0.0.1/32 md5
// # IPv6 local connections:
// host all ${USERNAME} ::1/128 trust
// host all postgres ::1/128 trust
// host all bun_sql_test ::1/128 trust
// host all bun_sql_test_scram ::1/128 scram-sha-256
//
// host all bun_sql_test_md5 ::1/128 md5
// # Allow replication connections from localhost, by a user with the
// # replication privilege.
// local replication all trust
Expand All @@ -33,9 +35,6 @@ if (!isCI) {
// --- Expected pg_hba.conf ---
process.env.DATABASE_URL = "postgres://bun_sql_test@localhost:5432/bun_sql_test";

const delay = ms => Bun.sleep(ms);
const rel = x => new URL(x, import.meta.url);

const login = {
username: "bun_sql_test",
};
Expand Down Expand Up @@ -598,9 +597,21 @@ if (!isCI) {
// return [true, (await postgres({ ...options, ...login })`select true as x`)[0].x]
// })

// t('Login using MD5', async() => {
// return [true, (await postgres({ ...options, ...login_md5 })`select true as x`)[0].x]
// })
test("Login using MD5", async () => {
await using sql = postgres({ ...options, ...login_md5 });
expect(await sql`select true as x`).toEqual([{ x: true }]);
});

test("Login with bad credentials propagates error from server", async () => {
const sql = postgres({ ...options, ...login_md5, username: "bad_user", password: "bad_password" });
let err;
try {
await sql`select true as x`;
} catch (e) {
err = e;
}
expect(err.code).toBe("ERR_POSTGRES_SERVER_ERROR");
});

test("Login using scram-sha-256", async () => {
await using sql = postgres({ ...options, ...login_scram });
Expand Down

0 comments on commit c13aa5f

Please sign in to comment.