diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index 85fcd50ba9745c..14ac6e9de919a6 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -1027,100 +1027,107 @@ pub const PostgresSQLConnection = struct { pub const AuthenticationState = union(enum) { pending: void, - SASL: SASL, + none: void, ok: void, + SASL: SASL, + md5: void, pub fn zero(this: *AuthenticationState) void { - const bytes = std.mem.asBytes(this); - @memset(bytes, 0); + switch (this.*) { + .SASL => |*sasl| { + sasl.deinit(); + }, + else => {}, + } + this.* = .{ .none = {} }; } + }; - pub const SASL = struct { - const nonce_byte_len = 18; - const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); + pub const SASL = struct { + const nonce_byte_len = 18; + const nonce_base64_len = bun.base64.encodeLenFromSize(nonce_byte_len); - const server_signature_byte_len = 32; - const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); + const server_signature_byte_len = 32; + const server_signature_base64_len = bun.base64.encodeLenFromSize(server_signature_byte_len); - const salted_password_byte_len = 32; + const salted_password_byte_len = 32; - nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, - nonce_len: u8 = 0, + nonce_base64_bytes: [nonce_base64_len]u8 = .{0} ** nonce_base64_len, + nonce_len: u8 = 0, - server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, - server_signature_len: u8 = 0, + server_signature_base64_bytes: [server_signature_base64_len]u8 = .{0} ** server_signature_base64_len, + server_signature_len: u8 = 0, - salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, - salted_password_created: bool = false, + salted_password_bytes: [salted_password_byte_len]u8 = .{0} ** salted_password_byte_len, + salted_password_created: bool = false, - status: SASLStatus = .init, + status: SASLStatus = .init, - pub const SASLStatus = enum { - init, - @"continue", - }; + pub const SASLStatus = enum { + init, + @"continue", + }; - fn hmac(password: []const u8, data: []const u8) ?[32]u8 { - var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8); + fn hmac(password: []const u8, data: []const u8) ?[32]u8 { + var buf = std.mem.zeroes([bun.BoringSSL.EVP_MAX_MD_SIZE]u8); - // TODO: I don't think this is failable. - const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; + // TODO: I don't think this is failable. + const result = bun.hmac.generate(password, data, .sha256, &buf) orelse return null; - assert(result.len == 32); - return buf[0..32].*; - } + assert(result.len == 32); + return buf[0..32].*; + } - pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { - this.salted_password_created = true; - if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { - return error.PBKDFD2; - } + pub fn computeSaltedPassword(this: *SASL, salt_bytes: []const u8, iteration_count: u32, connection: *PostgresSQLConnection) !void { + this.salted_password_created = true; + if (Crypto.EVP.pbkdf2(&this.salted_password_bytes, connection.password, salt_bytes, iteration_count, .sha256) == null) { + return error.PBKDFD2; } + } - pub fn saltedPassword(this: *const SASL) []const u8 { - assert(this.salted_password_created); - return this.salted_password_bytes[0..salted_password_byte_len]; - } + pub fn saltedPassword(this: *const SASL) []const u8 { + assert(this.salted_password_created); + return this.salted_password_bytes[0..salted_password_byte_len]; + } - pub fn serverSignature(this: *const SASL) []const u8 { - assert(this.server_signature_len > 0); - return this.server_signature_base64_bytes[0..this.server_signature_len]; - } + pub fn serverSignature(this: *const SASL) []const u8 { + assert(this.server_signature_len > 0); + return this.server_signature_base64_bytes[0..this.server_signature_len]; + } - pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { - assert(this.server_signature_len == 0); + pub fn computeServerSignature(this: *SASL, auth_string: []const u8) !void { + assert(this.server_signature_len == 0); - const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; - const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; - this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); - } + const server_key = hmac(this.saltedPassword(), "Server Key") orelse return error.InvalidServerKey; + const server_signature_bytes = hmac(&server_key, auth_string) orelse return error.InvalidServerSignature; + this.server_signature_len = @intCast(bun.base64.encode(&this.server_signature_base64_bytes, &server_signature_bytes)); + } - pub fn clientKey(this: *const SASL) [32]u8 { - return hmac(this.saltedPassword(), "Client Key").?; - } + pub fn clientKey(this: *const SASL) [32]u8 { + return hmac(this.saltedPassword(), "Client Key").?; + } - pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { - var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); - bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); - return hmac(&sha_digest, auth_string).?; - } + pub fn clientKeySignature(_: *const SASL, client_key: []const u8, auth_string: []const u8) [32]u8 { + var sha_digest = std.mem.zeroes(bun.sha.SHA256.Digest); + bun.sha.SHA256.hash(client_key, &sha_digest, JSC.VirtualMachine.get().rareData().boringEngine()); + return hmac(&sha_digest, auth_string).?; + } - pub fn nonce(this: *SASL) []const u8 { - if (this.nonce_len == 0) { - var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len; - bun.rand(&bytes); - this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); - } - return this.nonce_base64_bytes[0..this.nonce_len]; + pub fn nonce(this: *SASL) []const u8 { + if (this.nonce_len == 0) { + var bytes: [nonce_byte_len]u8 = .{0} ** nonce_byte_len; + bun.rand(&bytes); + this.nonce_len = @intCast(bun.base64.encode(&this.nonce_base64_bytes, &bytes)); } + return this.nonce_base64_bytes[0..this.nonce_len]; + } - pub fn deinit(this: *SASL) void { - this.nonce_len = 0; - this.salted_password_created = false; - this.server_signature_len = 0; - this.status = .init; - } - }; + pub fn deinit(this: *SASL) void { + this.nonce_len = 0; + this.salted_password_created = false; + this.server_signature_len = 0; + this.status = .init; + } }; pub const Status = enum { @@ -1319,6 +1326,7 @@ pub const PostgresSQLConnection = struct { if (this.status == .failed) return; this.status = .failed; + this.ref(); defer this.deref(); if (!this.socket.isClosed()) this.socket.close(); @@ -1540,6 +1548,7 @@ pub const PostgresSQLConnection = struct { this.read_buffer.write(bun.default_allocator, data[offset..]) catch @panic("failed to write to read buffer"); } else { bun.handleErrorReturnTrace(err, @errorReturnTrace()); + this.fail("Failed to read data", err); } }; @@ -2614,6 +2623,43 @@ pub const PostgresSQLConnection = struct { this.flushData(); }, + .MD5Password => |md5| { + debug("MD5Password", .{}); + // Format is: md5 + md5(md5(password + username) + salt) + var first_hash_buf: bun.sha.MD5.Digest = undefined; + var first_hash_str: [32]u8 = undefined; + var final_hash_buf: bun.sha.MD5.Digest = undefined; + var final_hash_str: [32]u8 = undefined; + var final_password_buf: [36]u8 = undefined; + + // First hash: md5(password + username) + var first_hasher = bun.sha.MD5.init(); + first_hasher.update(this.password); + first_hasher.update(this.user); + first_hasher.final(&first_hash_buf); + const first_hash_str_output = std.fmt.bufPrint(&first_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&first_hash_buf)}) catch unreachable; + + // Second hash: md5(first_hash + salt) + var final_hasher = bun.sha.MD5.init(); + final_hasher.update(first_hash_str_output); + final_hasher.update(&md5.salt); + final_hasher.final(&final_hash_buf); + const final_hash_str_output = std.fmt.bufPrint(&final_hash_str, "{x}", .{std.fmt.fmtSliceHexLower(&final_hash_buf)}) catch unreachable; + + // Format final password as "md5" + final_hash + const final_password = std.fmt.bufPrintZ(&final_password_buf, "md5{s}", .{final_hash_str_output}) catch unreachable; + + var response = protocol.PasswordMessage{ + .password = .{ + .temporary = final_password, + }, + }; + + this.authentication_state = .{ .md5 = {} }; + try response.writeInternal(PostgresSQLConnection.Writer, this.writer()); + this.flushData(); + }, + else => { debug("TODO auth: {s}", .{@tagName(std.meta.activeTag(auth))}); this.fail("TODO: support authentication method: {s}", error.UNSUPPORTED_AUTHENTICATION_METHOD); @@ -2634,18 +2680,12 @@ pub const PostgresSQLConnection = struct { var err: protocol.ErrorResponse = undefined; try err.decodeInternal(Context, reader); - if (this.status == .connecting) { - this.status = .failed; + if (this.status == .connecting or this.status == .sent_startup_message) { defer { err.deinit(); - this.poll_ref.unref(this.globalObject.bunVM()); - this.updateHasPendingActivity(); } - const on_connect = this.consumeOnConnectCallback(this.globalObject) orelse return; - defer on_connect.ensureStillAlive(); - const js_value = this.js_value; - this.globalObject.queueMicrotask(on_connect, &[_]JSValue{ err.toJS(this.globalObject), js_value }); + this.failWithJSValue(err.toJS(this.globalObject)); // it shouldn't enqueue any requests while connecting bun.assert(this.requests.count == 0); diff --git a/src/sql/postgres/postgres_protocol.zig b/src/sql/postgres/postgres_protocol.zig index 0d6ddee4012b6f..60eeaf9f9dce36 100644 --- a/src/sql/postgres/postgres_protocol.zig +++ b/src/sql/postgres/postgres_protocol.zig @@ -539,9 +539,6 @@ pub const Authentication = union(enum) { }, 5 => { if (message_length != 12) return error.InvalidMessageLength; - if (!try reader.expectInt(u32, 5)) { - return error.InvalidMessage; - } var salt_data = try reader.bytes(4); defer salt_data.deinit(); this.* = .{ diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index 11c63b02f05f1e..63c6fd1aa18f1b 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -13,18 +13,20 @@ if (!isCI) { // local all postgres trust // local all bun_sql_test_scram scram-sha-256 // local all bun_sql_test trust - // + // local all bun_sql_test_md5 md5 + // # IPv4 local connections: // host all ${USERNAME} 127.0.0.1/32 trust // host all postgres 127.0.0.1/32 trust // host all bun_sql_test_scram 127.0.0.1/32 scram-sha-256 // host all bun_sql_test 127.0.0.1/32 trust + // host all bun_sql_test_md5 127.0.0.1/32 md5 // # IPv6 local connections: // host all ${USERNAME} ::1/128 trust // host all postgres ::1/128 trust // host all bun_sql_test ::1/128 trust // host all bun_sql_test_scram ::1/128 scram-sha-256 - // + // host all bun_sql_test_md5 ::1/128 md5 // # Allow replication connections from localhost, by a user with the // # replication privilege. // local replication all trust @@ -33,9 +35,6 @@ if (!isCI) { // --- Expected pg_hba.conf --- process.env.DATABASE_URL = "postgres://bun_sql_test@localhost:5432/bun_sql_test"; - const delay = ms => Bun.sleep(ms); - const rel = x => new URL(x, import.meta.url); - const login = { username: "bun_sql_test", }; @@ -598,9 +597,21 @@ if (!isCI) { // return [true, (await postgres({ ...options, ...login })`select true as x`)[0].x] // }) - // t('Login using MD5', async() => { - // return [true, (await postgres({ ...options, ...login_md5 })`select true as x`)[0].x] - // }) + test("Login using MD5", async () => { + await using sql = postgres({ ...options, ...login_md5 }); + expect(await sql`select true as x`).toEqual([{ x: true }]); + }); + + test("Login with bad credentials propagates error from server", async () => { + const sql = postgres({ ...options, ...login_md5, username: "bad_user", password: "bad_password" }); + let err; + try { + await sql`select true as x`; + } catch (e) { + err = e; + } + expect(err.code).toBe("ERR_POSTGRES_SERVER_ERROR"); + }); test("Login using scram-sha-256", async () => { await using sql = postgres({ ...options, ...login_scram });