Skip to content

Commit

Permalink
proxy: add per query array mode flag (#6678)
Browse files Browse the repository at this point in the history
## Problem

Drizzle needs to be able to configure the array_mode flag per query.

## Summary of changes

Adds an optional array_mode flag to the per-query JSON data; when the flag is
omitted, the value from the request header is used as the default.
  • Loading branch information
conradludgate authored Feb 9, 2024
1 parent 951c9bf commit ea089dc
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 74 deletions.
157 changes: 83 additions & 74 deletions proxy/src/serverless/sql_over_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ use super::json::pg_text_row_to_json;
use super::SERVERLESS_DRIVER_SNI;

/// A single SQL query request, deserialized from the JSON request body.
/// Field names are exposed in camelCase on the wire (e.g. `arrayMode`).
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
// The SQL text to execute.
query: String,
// Query parameters; converted to Postgres text-format values on deserialize.
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
// Optional per-query override of the array-mode output flag. When the
// `arrayMode` key is absent this is `None` (via `#[serde(default)]`), and
// the caller falls back to the request-level header value.
#[serde(default)]
array_mode: Option<bool>,
}

#[derive(serde::Deserialize)]
Expand Down Expand Up @@ -330,7 +333,7 @@ async fn handle_inner(
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);
let default_array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE);

// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
Expand Down Expand Up @@ -402,83 +405,87 @@ async fn handle_inner(
// Now execute the query and return the result
//
let mut size = 0;
let result =
match payload {
Payload::Single(stmt) => {
let (status, results) =
query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
.await
.map_err(|e| {
client.discard();
e
})?;
client.check_idle(status);
results
let result = match payload {
Payload::Single(stmt) => {
let (status, results) =
query_to_json(&*client, stmt, &mut 0, raw_output, default_array_mode)
.await
.map_err(|e| {
client.discard();
e
})?;
client.check_idle(status);
results
}
Payload::Batch(statements) => {
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if txn_read_only {
builder = builder.read_only(true);
}
if txn_deferrable {
builder = builder.deferrable(true);
}
Payload::Batch(statements) => {
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if txn_read_only {
builder = builder.read_only(true);
}
if txn_deferrable {
builder = builder.deferrable(true);
}

let transaction = builder.start().await.map_err(|e| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
e
})?;

let results =
match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
.await
{
Ok(results) => {
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
results
}
Err(err) => {
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
return Err(err);
}
};

if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
let transaction = builder.start().await.map_err(|e| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
e
})?;

let results = match query_batch(
&transaction,
statements,
&mut size,
raw_output,
default_array_mode,
)
.await
{
Ok(results) => {
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
results
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
Err(err) => {
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
return Err(err);
}
json!({ "results": results })
};

if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
};
json!({ "results": results })
}
};

ctx.set_success();
ctx.log();
Expand Down Expand Up @@ -524,7 +531,7 @@ async fn query_to_json<T: GenericClient>(
data: QueryData,
current_size: &mut usize,
raw_output: bool,
array_mode: bool,
default_array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
let query_params = data.params;
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
Expand Down Expand Up @@ -578,6 +585,8 @@ async fn query_to_json<T: GenericClient>(
columns.push(client.get_type(c.type_oid()).await?);
}

let array_mode = data.array_mode.unwrap_or(default_array_mode);

// convert rows to JSON
let rows = rows
.iter()
Expand Down
33 changes: 33 additions & 0 deletions test_runner/regress/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,39 @@ def qq(
assert result[0]["rows"] == [{"answer": 42}]


def test_sql_over_http_batch_output_options(static_proxy: NeonProxy):
    """Per-query arrayMode flags in a batch override the request-level default."""
    static_proxy.safe_psql("create role http with login password 'http' superuser")

    connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"

    # One batch, two identical queries — only the arrayMode flag differs.
    batch = {
        "queries": [
            {"query": "select $1 as answer", "params": [42], "arrayMode": True},
            {"query": "select $1 as answer", "params": [42], "arrayMode": False},
        ]
    }
    resp = requests.post(
        f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
        data=json.dumps(batch),
        headers={
            "Content-Type": "application/sql",
            "Neon-Connection-String": connstr,
            "Neon-Batch-Isolation-Level": "Serializable",
            "Neon-Batch-Read-Only": "false",
            "Neon-Batch-Deferrable": "false",
        },
        verify=str(static_proxy.test_output_dir / "proxy.crt"),
    )
    assert resp.status_code == 200

    as_array, as_object = resp.json()["results"]

    # arrayMode=True → rows come back as positional arrays.
    assert as_array["rowAsArray"]
    assert as_array["rows"] == [["42"]]

    # arrayMode=False → rows come back as column-name keyed objects.
    assert not as_object["rowAsArray"]
    assert as_object["rows"] == [{"answer": "42"}]


def test_sql_over_http_pool(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")

Expand Down

1 comment on commit ea089dc

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2493 tests run: 2363 passed, 5 failed, 125 skipped (full report)


Failures on Postgres 14

  • test_multixact: release
  • test_pageserver_max_throughput_getpage_at_latest_lsn[10-6-30]: release
  • test_pageserver_max_throughput_getpage_at_latest_lsn[10-13-30]: release
  • test_pageserver_max_throughput_getpage_at_latest_lsn[1-6-30]: release
  • test_pageserver_max_throughput_getpage_at_latest_lsn[1-13-30]: release
# Run all failed tests locally:
scripts/pytest -vv -n $(nproc) -k "test_multixact[release-pg14] or test_pageserver_max_throughput_getpage_at_latest_lsn[10-6-30] or test_pageserver_max_throughput_getpage_at_latest_lsn[10-13-30] or test_pageserver_max_throughput_getpage_at_latest_lsn[1-6-30] or test_pageserver_max_throughput_getpage_at_latest_lsn[1-13-30]"
Flaky tests (1)

Postgres 14

  • test_ondemand_download_timetravel: debug

Test coverage report is not available

The comment gets automatically updated with the latest test results
ea089dc at 2024-02-09T11:43:38.307Z :recycle:

Please sign in to comment.