From 1067fbe8d5f0e353e47518296a37fdfb10738bfd Mon Sep 17 00:00:00 2001 From: David Gregory Date: Wed, 25 Jun 2025 00:19:19 +0100 Subject: [PATCH] Add support for Scala 3 enums --- .../play/api/libs/json/JsMacroImpl.scala | 12 +++- .../play/api/libs/json/QuotesHelper.scala | 61 ++++++------------- .../play/api/libs/json/MacroScala3Spec.scala | 49 +++++++++++++++ 3 files changed, 77 insertions(+), 45 deletions(-) diff --git a/play-json/shared/src/main/scala-3/play/api/libs/json/JsMacroImpl.scala b/play-json/shared/src/main/scala-3/play/api/libs/json/JsMacroImpl.scala index c00872d51..c57151db2 100644 --- a/play-json/shared/src/main/scala-3/play/api/libs/json/JsMacroImpl.scala +++ b/play-json/shared/src/main/scala-3/play/api/libs/json/JsMacroImpl.scala @@ -241,8 +241,12 @@ object JsMacroImpl { // TODO: debug TypeRepr.of[String] ) + val termSym = tpr.termSymbol + val typeSym = tpr.typeSymbol + val isCase = termSym.flags.is(Flags.Case) + val tpeCaseName: Expr[String] = '{ - ${ config }.typeNaming(${ Expr(typeName(tpr.typeSymbol)) }) + ${ config }.typeNaming(${ Expr(typeName(if isCase then termSym else typeSym)) }) } val resolve = resolver[Reads, sub]( @@ -595,8 +599,12 @@ object JsMacroImpl { // TODO: debug subTpr ) + val termSym = tpr.termSymbol + val typeSym = tpr.typeSymbol + val isCase = termSym.flags.is(Flags.Case) + val tpeCaseName: Expr[String] = '{ - ${ config }.typeNaming(${ Expr(typeName(tpr.typeSymbol)) }) + ${ config }.typeNaming(${ Expr(typeName(if isCase then termSym else typeSym)) }) } val resolve = resolver[Writes, sub]( diff --git a/play-json/shared/src/main/scala-3/play/api/libs/json/QuotesHelper.scala b/play-json/shared/src/main/scala-3/play/api/libs/json/QuotesHelper.scala index 48612faea..2095269b3 100644 --- a/play-json/shared/src/main/scala-3/play/api/libs/json/QuotesHelper.scala +++ b/play-json/shared/src/main/scala-3/play/api/libs/json/QuotesHelper.scala @@ -4,12 +4,11 @@ package play.api.libs.json +import scala.deriving.Mirror import scala.util.Try as TryResult import scala.util.Success as TrySuccess import scala.util.Failure as TryFailure - import scala.deriving.Mirror.ProductOf - import scala.quoted.Expr import scala.quoted.Quotes import scala.quoted.Type @@ -43,51 +42,27 @@ private[json] trait QuotesHelper { * Class `Lorem` is listed through `SubFoo`, * but `SubFoo` itself is not returned. */ - final def knownSubclasses(tpr: TypeRepr): Option[List[TypeRepr]] = - tpr.classSymbol.flatMap { cls => - @annotation.tailrec - def subclasses( - children: List[Tree], - out: List[TypeRepr] - ): List[TypeRepr] = { - val childTpr = children.headOption.collect { - case tpd: Typed => - tpd.tpt.tpe - - case vd: ValDef => - vd.tpt.tpe - - case cd: ClassDef => - cd.constructor.returnTpt.tpe - - } - - childTpr match { - case Some(child) => { - val tpeSym = child.typeSymbol - - if (tpeSym.flags.is(Flags.Abstract) && - tpeSym.flags.is(Flags.Sealed) && - !(child <:< anyValTpe)) || - (tpeSym.flags.is(Flags.Sealed) && - tpeSym.flags.is(Flags.Trait)) - then { - // Ignore sub-trait itself, but check the sub-sub-classes - subclasses(tpeSym.children.map(_.tree) ::: children.tail, out) - } else { - subclasses(children.tail, child :: out) - } + final def knownSubclasses(tpr: TypeRepr): Option[List[TypeRepr]] = { + def gatherNestedSubtypes[Parent: Type, Elems: Type](using Quotes): List[TypeRepr] = + Type.of[Elems] match { + case '[elem *: elems] => + Expr.summon[Mirror.Of[elem]] match { + case Some('{ $sum: Mirror.SumOf[elem] { type MirroredElemTypes = elementTypes } }) => + gatherNestedSubtypes[elem, elementTypes] ++ gatherNestedSubtypes[Parent, elems] + case _ => + TypeRepr.of[elem] :: gatherNestedSubtypes[Parent, elems] } - - case _ => - out.reverse - } + case '[EmptyTuple] => Nil } - val types = subclasses(cls.children.map(_.tree), Nil) - - if types.isEmpty then None else Some(types) + tpr.asType match { + case '[t] => + Expr.summon[Mirror.Of[t]].collect { + case '{ $sum: Mirror.SumOf[t] { type MirroredElemTypes = elementTypes } } => + gatherNestedSubtypes[t, elementTypes] + } } + } @annotation.tailrec private def withElems[U <: Product]( diff --git a/play-json/shared/src/test/scala-3/play/api/libs/json/MacroScala3Spec.scala b/play-json/shared/src/test/scala-3/play/api/libs/json/MacroScala3Spec.scala index a957fb9e7..bd8972f07 100644 --- a/play-json/shared/src/test/scala-3/play/api/libs/json/MacroScala3Spec.scala +++ b/play-json/shared/src/test/scala-3/play/api/libs/json/MacroScala3Spec.scala @@ -4,12 +4,14 @@ package play.api.libs.json +import org.scalatest.EitherValues import org.scalatest.matchers.must.Matchers import org.scalatest.wordspec.AnyWordSpec final class MacroScala3Spec extends AnyWordSpec with Matchers + with EitherValues with org.scalatestplus.scalacheck.ScalaCheckPropertyChecks { "Case class" should { "not be handled" when { @@ -57,6 +59,44 @@ final class MacroScala3Spec } } } + + "Scala 3 enum" should { + "be handled" when { + "declared with no-arg cases" in { + given Format[Color.Red.type] = Json.format[Color.Red.type] + given Format[Color.Green.type] = Json.format[Color.Green.type] + given Format[Color.Blue.type] = Json.format[Color.Blue.type] + val format = Json.format[Color] + + val redJson = Json.obj("_type" -> "play.api.libs.json.Color.Red") + format.writes(Color.Red).mustEqual(redJson) + format.reads(redJson).asEither.value.mustEqual(Color.Red) + + val greenJson = Json.obj("_type" -> "play.api.libs.json.Color.Green") + format.writes(Color.Green).mustEqual(greenJson) + format.reads(greenJson).asEither.value.mustEqual(Color.Green) + + val blueJson = Json.obj("_type" -> "play.api.libs.json.Color.Blue") + format.writes(Color.Blue).mustEqual(blueJson) + format.reads(blueJson).asEither.value.mustEqual(Color.Blue) + } + } + + "declared with single-arg case" in { + given Format[IntOption.Some] = Json.format[IntOption.Some] + given Format[IntOption.None.type] = Json.format[IntOption.None.type] + given format: Format[IntOption] = Json.format[IntOption] + + val someValue = IntOption.Some(1) + val someJson = Json.obj("_type" -> "play.api.libs.json.IntOption.Some", "value" -> 1) + format.writes(someValue).mustEqual(someJson) + format.reads(someJson).asEither.value.mustEqual(someValue) + + val noneJson = Json.obj("_type" -> "play.api.libs.json.IntOption.None") + format.writes(IntOption.None).mustEqual(noneJson) + format.reads(noneJson).asEither.value.mustEqual(IntOption.None) + } + } } final class CustomNoProductOf(val name: String, val age: Int) @@ -66,3 +106,12 @@ object CustomNoProductOf { given Conversion[CustomNoProductOf, Tuple2[String, Int]] = (v: CustomNoProductOf) => v.name -> v.age } + +enum Color { + case Red, Green, Blue +} + +enum IntOption { + case Some(value: Int) + case None +}