diff --git a/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala b/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala index 7c77c70a..45db33aa 100644 --- a/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala +++ b/modules/core/src/main/scala/sangria/schema/AstSchemaMaterializer.scala @@ -576,10 +576,10 @@ class AstSchemaMaterializer[Ctx] private ( def getNamedType( origin: MatOrigin, typeName: String, - location: Option[AstLocation]): Type with Named = - typeDefCache.getOrElseUpdate( - origin -> typeName, - Schema.getBuiltInType(typeName).getOrElse { + location: Option[AstLocation]): Type with Named = { + + def update = { + val builtType = { val existing = existingDefsMat.get(typeName).toVector val sdl = typeDefsMat.filter(_.name == typeName) val additional = builder.additionalTypes.filter(_.name == typeName).toVector @@ -606,11 +606,21 @@ class AstSchemaMaterializer[Ctx] private ( getNamedType(origin, allCandidates.head) } else None - builtType.getOrElse( - throw MaterializedSchemaValidationError(Vector( - UnknownTypeViolation(typeName, Seq.empty, document.sourceMapper, location.toList)))) + builtType } + + builtType + .orElse(Schema.getBuiltInType(typeName)) + .getOrElse(throw MaterializedSchemaValidationError(Vector( + UnknownTypeViolation(typeName, Seq.empty, document.sourceMapper, location.toList)))) + + } + + typeDefCache.getOrElseUpdate( + origin -> typeName, + update ) + } def getNamedType(origin: MatOrigin, tpe: MaterializedType): Option[Type with Named] = tpe match { diff --git a/modules/core/src/test/scala/sangria/schema/CustomScalarSpec.scala b/modules/core/src/test/scala/sangria/schema/CustomScalarSpec.scala index 6ba2c9a5..70889165 100644 --- a/modules/core/src/test/scala/sangria/schema/CustomScalarSpec.scala +++ b/modules/core/src/test/scala/sangria/schema/CustomScalarSpec.scala @@ -4,6 +4,7 @@ import java.text.SimpleDateFormat import java.util.Date import sangria.ast +import sangria.macros._ import sangria.util.Pos import sangria.util.SimpleGraphQlSupport._ import sangria.validation.ValueCoercionViolation @@ -11,6 +12,7 @@ import sangria.validation.ValueCoercionViolation import scala.util.{Failure, Success, Try} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import sangria.validation.BigDecimalCoercionViolation class CustomScalarSpec extends AnyWordSpec with Matchers { "Schema" should { @@ -80,5 +82,48 @@ class CustomScalarSpec extends AnyWordSpec with Matchers { Pos(3, 28))) ) } + + "allow to overwrite built-in scalar type" in { + + def parseBigDecimal(s: String) = Try(BigDecimal(s)) match { + case Success(d) => Right(d) + case Failure(error) => Left(BigDecimalCoercionViolation) + } + + val BigDecimalType = ScalarType[BigDecimal]( + "BigDecimal", + description = Some("A string only BigDecimal type"), + coerceOutput = (d, _) => d.toString, + coerceUserInput = { + case s: String => parseBigDecimal(s) + case _ => Left(BigDecimalCoercionViolation) + }, + coerceInput = { + case ast.StringValue(s, _, _, _, _) => parseBigDecimal(s) + case _ => Left(BigDecimalCoercionViolation) + } + ) + + val schemaSdl = graphql""" + scalar BigDecimal + + type Query { + age: BigDecimal! + } + """ + + val schema = Schema.buildFromAst[Unit]( + schemaSdl, + AstSchemaBuilder.resolverBased( + ScalarResolver[Unit] { case ast.ScalarTypeDefinition("BigDecimal", _, _, _, _) => + BigDecimalType + } + ) + ) + + schema.scalarTypes.collectFirst { case ("BigDecimal", tpe) => tpe } should be( + Some(BigDecimalType)) + + } } }