From aae830d74b24a576e541afbdfc3acfcebb29c6f5 Mon Sep 17 00:00:00 2001 From: Vincent de Haan Date: Wed, 7 Sep 2022 13:23:09 +0200 Subject: [PATCH] Add Fragments.in with 3 parameters --- modules/core/src/main/scala/doobie/util/fragments.scala | 4 ++++ modules/core/src/test/scala/doobie/util/FragmentsSuite.scala | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/modules/core/src/main/scala/doobie/util/fragments.scala b/modules/core/src/main/scala/doobie/util/fragments.scala index 815c41eb2..950b794ec 100644 --- a/modules/core/src/main/scala/doobie/util/fragments.scala +++ b/modules/core/src/main/scala/doobie/util/fragments.scala @@ -24,6 +24,10 @@ object fragments { def in[F[_]: Reducible, A: util.Put, B: util.Put](f: Fragment, fs: F[(A,B)]): Fragment = fs.toList.map { case (a,b) => fr0"($a,$b)" }.foldSmash1(f ++ fr0"IN (", fr",", fr")") + /** Returns `f IN ((fs0-A, fs0-B, fs0-C), (fs1-A, fs1-B, fs1-C), ...)`. */ + def in[F[_]: Reducible, A: util.Put, B: util.Put, C: util.Put](f: Fragment, fs: F[(A,B,C)]): Fragment = + fs.toList.map { case (a,b,c) => fr0"($a,$b,$c)" }.foldSmash1(f ++ fr0"IN (", fr",", fr")") + /** Returns `f NOT IN (fs0, fs1, ...)`. */ def notIn[F[_]: Reducible, A: util.Put](f: Fragment, fs: F[A]): Fragment = fs.toList.map(a => fr0"$a").foldSmash1(f ++ fr0"NOT IN (", fr",", fr")") diff --git a/modules/core/src/test/scala/doobie/util/FragmentsSuite.scala b/modules/core/src/test/scala/doobie/util/FragmentsSuite.scala index 808c3e8a8..2e75b8737 100644 --- a/modules/core/src/test/scala/doobie/util/FragmentsSuite.scala +++ b/modules/core/src/test/scala/doobie/util/FragmentsSuite.scala @@ -40,6 +40,10 @@ class FragmentsSuite extends munit.FunSuite { assertEquals(in(fr"foo", NonEmptyList.of((1, true), (2, false))).query[Unit].sql, "foo IN ((?,?), (?,?)) ") } + test("in for three columns") { + assertEquals(in(fr"foo", NonEmptyList.of((1, true, 3), (2, false, 4))).query[Unit].sql, "foo IN ((?,?,?), (?,?,?)) ") + } + test("notIn") { assertEquals(notIn(fr"foo", nel).query[Unit].sql, "foo NOT IN (?, ?, ?) ") }