Skip to content

Commit

Permalink
feat(citext): support postgres citext (launchbadge#2478)
Browse files Browse the repository at this point in the history
* feat(citext): implement citext for postgres

* feat(citext): add citext -> String conversion test

* feat(citext): fix ltree -> citree

* feat(citext): add citext to the setup.sql

* chore: address nits to launchbadge#2478

* Rename `PgCitext` to `PgCiText`
* Document when use of `PgCiText` is warranted
* Document potentially surprising `PartialEq` behavior
* Test that the macros consider `CITEXT` to be compatible with `String` and friends

* doc: add `PgCiText` to `postgres::types` listing

* chore: restore missing trailing line break to `tests/postgres/setup.sql`

---------

Co-authored-by: Austin Bonander <[email protected]>
  • Loading branch information
hgranthorner and abonander authored Oct 12, 2023
1 parent 3c2471e commit 56945d7
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 3 deletions.
2 changes: 2 additions & 0 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use sqlx_core::connection::Connection;
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::ext::ustr::UStr;
use sqlx_core::transaction::TransactionManager;

sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres);
Expand Down Expand Up @@ -178,6 +179,7 @@ impl<'a> TryFrom<&'a PgTypeInfo> for AnyTypeInfo {
PgType::Float8 => AnyTypeInfoKind::Double,
PgType::Bytea => AnyTypeInfoKind::Blob,
PgType::Text => AnyTypeInfoKind::Text,
PgType::DeclareWithName(UStr::Static("citext")) => AnyTypeInfoKind::Text,
_ => {
return Err(sqlx_core::Error::AnyDriverError(
format!("Any driver does not support the Postgres type {pg_type:?}").into(),
Expand Down
2 changes: 2 additions & 0 deletions sqlx-postgres/src/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ impl PgType {
PgType::Int8RangeArray => Oid(3927),
PgType::Jsonpath => Oid(4072),
PgType::JsonpathArray => Oid(4073),

PgType::Custom(ty) => ty.oid,

PgType::DeclareWithOid(oid) => *oid,
Expand Down Expand Up @@ -874,6 +875,7 @@ impl PgType {
PgType::Unknown => None,
// There is no `VoidArray`
PgType::Void => None,

PgType::Custom(ty) => match &ty.kind {
PgTypeKind::Simple => None,
PgTypeKind::Pseudo => None,
Expand Down
106 changes: 106 additions & 0 deletions sqlx-postgres/src/types/citext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use crate::types::array_compatible;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::ops::Deref;
use std::str::FromStr;

/// Case-insensitive text (`citext`) support for Postgres.
///
/// Note that SQLx considers the `citext` type to be compatible with `String`
/// and its various derivatives, so direct usage of this type is generally unnecessary.
///
/// However, it may be needed, for example, when binding a `citext[]` array,
/// as Postgres will generally not accept a `text[]` array (mapped from `Vec<String>`) in its place.
///
/// See [the Postgres manual, Appendix F, Section 10][PG.F.10] for details on using `citext`.
///
/// [PG.F.10]: https://www.postgresql.org/docs/current/citext.html
///
/// ### Note: Extension Required
/// The `citext` extension is not enabled by default in Postgres. You will need to do so explicitly:
///
/// ```ignore
/// CREATE EXTENSION IF NOT EXISTS "citext";
/// ```
///
/// ### Note: `PartialEq` is Case-Sensitive
/// This type derives `PartialEq` which forwards to the implementation on `String`, which
/// is case-sensitive. This impl exists mainly for testing.
///
/// To properly emulate the case-insensitivity of `citext` would require use of locale-aware
/// functions in `libc`, and even then would require querying the locale of the database server
/// and setting it locally, which is unsafe.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct PgCiText(pub String);

impl Type<Postgres> for PgCiText {
fn type_info() -> PgTypeInfo {
// Since `citext` is enabled by an extension, it does not have a stable OID.
PgTypeInfo::with_name("citext")
}

fn compatible(ty: &PgTypeInfo) -> bool {
<&str as Type<Postgres>>::compatible(ty)
}
}

impl Deref for PgCiText {
type Target = str;

fn deref(&self) -> &Self::Target {
self.0.as_str()
}
}

impl From<String> for PgCiText {
fn from(value: String) -> Self {
Self(value)
}
}

impl From<PgCiText> for String {
fn from(value: PgCiText) -> Self {
value.0
}
}

impl FromStr for PgCiText {
type Err = core::convert::Infallible;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(PgCiText(s.parse()?))
}
}

impl Display for PgCiText {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(&self.0)
}
}

impl PgHasArrayType for PgCiText {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::with_name("_citext")
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
array_compatible::<&str>(ty)
}
}

impl Encode<'_, Postgres> for PgCiText {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<&str as Encode<Postgres>>::encode(&**self, buf)
}
}

impl Decode<'_, Postgres> for PgCiText {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(PgCiText(value.as_str()?.to_owned()))
}
}
9 changes: 8 additions & 1 deletion sqlx-postgres/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
//! | `i64` | BIGINT, BIGSERIAL, INT8 |
//! | `f32` | REAL, FLOAT4 |
//! | `f64` | DOUBLE PRECISION, FLOAT8 |
//! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME |
//! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME, CITEXT |
//! | `&[u8]`, `Vec<u8>` | BYTEA |
//! | `()` | VOID |
//! | [`PgInterval`] | INTERVAL |
//! | [`PgRange<T>`](PgRange) | INT8RANGE, INT4RANGE, TSRANGE, TSTZRANGE, DATERANGE, NUMRANGE |
//! | [`PgMoney`] | MONEY |
//! | [`PgLTree`] | LTREE |
//! | [`PgLQuery`] | LQUERY |
//! | [`PgCiText`] | CITEXT<sup>1</sup> |
//!
//! <sup>1</sup> SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc.,
//! but this wrapper type is available for edge cases, such as `CITEXT[]` which Postgres
//! does not consider to be compatible with `TEXT[]`.
//!
//! ### [`bigdecimal`](https://crates.io/crates/bigdecimal)
//! Requires the `bigdecimal` Cargo feature flag.
Expand Down Expand Up @@ -175,6 +180,7 @@ pub(crate) use sqlx_core::types::{Json, Type};
mod array;
mod bool;
mod bytes;
mod citext;
mod float;
mod int;
mod interval;
Expand Down Expand Up @@ -224,6 +230,7 @@ mod mac_address;
mod bit_vec;

pub use array::PgHasArrayType;
pub use citext::PgCiText;
pub use interval::PgInterval;
pub use lquery::PgLQuery;
pub use lquery::PgLQueryLevel;
Expand Down
1 change: 1 addition & 0 deletions sqlx-postgres/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ impl Type<Postgres> for str {
PgTypeInfo::BPCHAR,
PgTypeInfo::VARCHAR,
PgTypeInfo::UNKNOWN,
PgTypeInfo::with_name("citext"),
]
.contains(ty)
}
Expand Down
25 changes: 25 additions & 0 deletions tests/postgres/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,28 @@ async fn test_bind_arg_override_wildcard() -> anyhow::Result<()> {

Ok(())
}

#[sqlx_macros::test]
async fn test_to_from_citext() -> anyhow::Result<()> {
// Ensure that the macros consider `CITEXT` to be compatible with `String` and friends

let mut conn = new::<Postgres>().await?;

let mut tx = conn.begin().await?;

let foo_in = "Hello, world!";

sqlx::query!("insert into test_citext(foo) values ($1)", foo_in)
.execute(&mut *tx)
.await?;

let foo_out: String = sqlx::query_scalar!("select foo from test_citext")
.fetch_one(&mut *tx)
.await?;

assert_eq!(foo_in, foo_out);

tx.rollback().await?;

Ok(())
}
7 changes: 7 additions & 0 deletions tests/postgres/setup.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
-- https://www.postgresql.org/docs/current/ltree.html
CREATE EXTENSION IF NOT EXISTS ltree;

-- https://www.postgresql.org/docs/current/citext.html
CREATE EXTENSION IF NOT EXISTS citext;

-- https://www.postgresql.org/docs/current/sql-createtype.html
CREATE TYPE status AS ENUM ('new', 'open', 'closed');

Expand Down Expand Up @@ -44,3 +47,7 @@ CREATE TABLE products (

CREATE OR REPLACE PROCEDURE forty_two(INOUT forty_two INT = NULL)
LANGUAGE plpgsql AS 'begin forty_two := 42; end;';

CREATE TABLE test_citext (
foo CITEXT NOT NULL
);
13 changes: 11 additions & 2 deletions tests/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate time_ as time;

use std::ops::Bound;

use sqlx::postgres::types::{Oid, PgInterval, PgMoney, PgRange};
use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange};
use sqlx::postgres::Postgres;
use sqlx_test::{test_decode_type, test_prepared_type, test_type};

Expand Down Expand Up @@ -65,6 +65,7 @@ test_type!(str<&str>(Postgres,
"'identifier'::name" == "identifier",
"'five'::char(4)" == "five",
"'more text'::varchar" == "more text",
"'case insensitive searching'::citext" == "case insensitive searching",
));

test_type!(string<String>(Postgres,
Expand All @@ -79,7 +80,7 @@ test_type!(string_vec<Vec<String>>(Postgres,
== vec!["", "\""],

"array['Hello, World', '', 'Goodbye']::text[]"
== vec!["Hello, World", "", "Goodbye"]
== vec!["Hello, World", "", "Goodbye"],
));

test_type!(string_array<[String; 3]>(Postgres,
Expand Down Expand Up @@ -550,6 +551,14 @@ test_prepared_type!(money_vec<Vec<PgMoney>>(Postgres,
"array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)],
));

test_prepared_type!(citext_array<Vec<PgCiText>>(Postgres,
"array['one','two','three']::citext[]" == vec![
PgCiText("one".to_string()),
PgCiText("two".to_string()),
PgCiText("three".to_string()),
],
));

// FIXME: needed to disable `ltree` tests in version that don't have a binary format for it
// but `PgLTree` should just fall back to text format
#[cfg(any(postgres_14, postgres_15))]
Expand Down

0 comments on commit 56945d7

Please sign in to comment.