From ededdccb7374733e3273092d55b9c8d58564bd6c Mon Sep 17 00:00:00 2001 From: Grant Date: Sun, 14 May 2023 14:09:54 -0400 Subject: [PATCH] feat(citext): implement citext for postgres --- sqlx-postgres/src/any.rs | 2 + sqlx-postgres/src/type_info.rs | 2 + sqlx-postgres/src/types/citext.rs | 92 +++++++++++++++++++++++++++++++ sqlx-postgres/src/types/mod.rs | 2 + sqlx-postgres/src/types/str.rs | 1 + tests/postgres/types.rs | 13 ++++- 6 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 sqlx-postgres/src/types/citext.rs diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 4d1c593dff..1c5f7ea33b 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -13,6 +13,7 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); @@ -179,6 +180,7 @@ impl<'a> TryFrom<&'a PgTypeInfo> for AnyTypeInfo { PgType::Float8 => AnyTypeInfoKind::Double, PgType::Bytea => AnyTypeInfoKind::Blob, PgType::Text => AnyTypeInfoKind::Text, + PgType::DeclareWithName(UStr::Static("citext")) => AnyTypeInfoKind::Text, _ => { return Err(sqlx_core::Error::AnyDriverError( format!( diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index ae211d0d3a..1c03ea20e0 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -438,6 +438,7 @@ impl PgType { PgType::Int8RangeArray => Oid(3927), PgType::Jsonpath => Oid(4072), PgType::JsonpathArray => Oid(4073), + PgType::Custom(ty) => ty.oid, PgType::DeclareWithOid(oid) => *oid, @@ -855,6 +856,7 @@ impl PgType { PgType::Unknown => None, // There is no `VoidArray` PgType::Void => None, + PgType::Custom(ty) => match &ty.kind { PgTypeKind::Simple => None, PgTypeKind::Pseudo => None, diff --git a/sqlx-postgres/src/types/citext.rs b/sqlx-postgres/src/types/citext.rs new file mode 100644 index 0000000000..530a75b0e2 --- /dev/null +++ b/sqlx-postgres/src/types/citext.rs @@ -0,0 +1,92 @@ +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Deref; +use std::str::FromStr; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::Type; +use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; +use crate::types::array_compatible; + +/// Text type for case insensitive searching in Postgres. +/// +/// See https://www.postgresql.org/docs/current/citext.html +/// +/// ### Note: Extension Required +/// The `citext` extension is not enabled by default in Postgres. You will need to do so explicitly: +/// +/// ```ignore +/// CREATE EXTENSION IF NOT EXISTS "citext"; +/// ``` + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct PgCitext(String); + +impl PgCitext { + pub fn new(s: String) -> Self { + Self(s) + } +} + +impl Type for PgCitext { + fn type_info() -> PgTypeInfo { + // Since `ltree` is enabled by an extension, it does not have a stable OID. + PgTypeInfo::with_name("citext") + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Deref for PgCitext { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0.as_str() + } +} + +impl From for PgCitext { + fn from(value: String) -> Self { + Self::new(value) + } +} + +impl FromStr for PgCitext { + type Err = core::convert::Infallible; + + fn from_str(s: &str) -> Result { + Ok(PgCitext(s.parse()?)) + } +} + +impl Display for PgCitext { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl PgHasArrayType for PgCitext { + fn array_type_info() -> PgTypeInfo { + PgTypeInfo::with_name("_citext") + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + array_compatible::<&str>(ty) + } +} + + +impl Encode<'_, Postgres> for PgCitext { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + <&str as Encode>::encode(&**self, buf) + } +} + +impl Decode<'_, Postgres> for PgCitext { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(PgCitext(value.as_str()?.to_owned())) + } +} diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 8749fe28ba..37030f09c9 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -180,6 +180,7 @@ mod int; mod interval; mod lquery; mod ltree; +mod citext; // Not behind a Cargo feature because we require JSON in the driver implementation. mod json; mod money; @@ -235,6 +236,7 @@ pub use ltree::PgLTreeParseError; pub use money::PgMoney; pub use oid::Oid; pub use range::PgRange; +pub use citext::PgCitext; #[cfg(any(feature = "chrono", feature = "time"))] pub use time_tz::PgTimeTz; diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index 53dda1f446..b42f7c9a49 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -18,6 +18,7 @@ impl Type for str { PgTypeInfo::BPCHAR, PgTypeInfo::VARCHAR, PgTypeInfo::UNKNOWN, + PgTypeInfo::with_name("citext") ] .contains(ty) } diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 2b2de07f0f..692f662a9c 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -2,7 +2,7 @@ extern crate time_ as time; use std::ops::Bound; -use sqlx::postgres::types::{Oid, PgInterval, PgMoney, PgRange}; +use sqlx::postgres::types::{Oid, PgInterval, PgMoney, PgRange, PgCitext}; use sqlx::postgres::Postgres; use sqlx_test::{test_decode_type, test_prepared_type, test_type}; @@ -79,7 +79,7 @@ test_type!(string_vec>(Postgres, == vec!["", "\""], "array['Hello, World', '', 'Goodbye']::text[]" - == vec!["Hello, World", "", "Goodbye"] + == vec!["Hello, World", "", "Goodbye"], )); test_type!(string_array<[String; 3]>(Postgres, @@ -549,6 +549,15 @@ test_prepared_type!(money_vec>(Postgres, "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], )); +test_prepared_type!(citext_array>(Postgres, + "array['one','two','three']::citext[]" == vec![ + PgCitext::new("one".to_string()), + PgCitext::new("two".to_string()), + PgCitext::new("three".to_string()), + ], +)); + + // FIXME: needed to disable `ltree` tests in version that don't have a binary format for it // but `PgLTree` should just fall back to text format #[cfg(any(postgres_14, postgres_15))]