株式会社ヘンリー エンジニアブログ

株式会社ヘンリーのエンジニアが技術情報を発信します

【Scala 3 macroがすごい】Compiletime API編

株式会社ヘンリーでメタプログラミングに没頭しているgiiitaです。

突然ですが皆さんはメタプログラミングに触れたことがあるでしょうか? プログラミング言語によって様々なメタプログラミングの機能があります。Javaやその派生言語では馴染み深い「リフレクション」はその代表例ですが、LispやRust, Scalaなどにはマクロと呼ばれる、compile前にコードを操作する仕組みがあります。 そんな中でも最近急激な変化を遂げている Scala3 のMacroがとんでもなくすごいんですが、なかなかまだ情報が出回っておらず、手を出すにはあまりにも敷居が高くなっているため、皆さんに面白さを知ってもらうべく紹介します。

マクロって何ができるの?

そもそもマクロに手を出しづらい背景として、何に使えるのか、いつ使うべきなのかよくわからないからというのが大きいのではないでしょうか? 現に私自身、手を出すまでイメージがつかず、わざわざ難解でExperimentalな機能 (Scala2系当時) を使ってまで何かをしたいというモチベーションがありませんでした。 とは言いつつも、今なら明確にこういうケースで使うべきですと言えるわけではないので難しいところですが、基本的には コードの変換自動生成 といったところが主な目的です。 そして何より、それら成果物を compilerによって静的に検査できる というのが旨味なわけです。

使ってみよう

今回はひとまず先日の3.3.0リリースを祝って、3.3.0でいろいろやっていきたいと思います。

Setup

まずはModule構成です。マクロを使用したmoduleは、利用するmoduleのbuild時にはbuildが完了していないといけないので、macroを定義するmoduleと呼び出すmoduleが必要です。 (確かこれもScala3で変更があったような気も...

build.sbt

lazy val Scala3_3  = "3.3.0"
scalaVersion in Scope.Global := Scala3_3

lazy val root = (project in file("."))
  .aggregate(
    macroModule,
    runtimeModule,
  )
lazy val macroModule = (project in file("macro-module"))
  .settings(
    libraryDependencies ++= {
      Seq(
        "org.scala-lang" %% "scala3-compiler" % scalaVersion.value,
      )
    }
  )
lazy val runtimeModule = (project in file("runtime-module"))
  .dependsOn(macroModule)

もはや何も特別な事はありません。 さて、何を作ろうか悩むところですが、わかりやすくここは誰もが通る道という事で、JsonParserを作っていきたいと思います。 String => JsonJson => String が必要なのでI/Fを切りましょう。この時、マクロから利用するものはmacro以下のmoduleに定義する必要があります。

trait Read[T] {
  def read(json: Json): T
}
trait Write[T] {
  def write(t: T): Json
}
trait Both[T] extends Read[T] with Write[T]

ここでは変異境界は無しでやっていきます。

trait Json extends Serializable {
  def nameOf(key: String): Json
}
case class JsonString(value: String) extends Json {
  override def toString: String = s""""$value""""
  override def nameOf(key: String): Json = JsonNull
}
case class JsonAny(value: String) extends Json {
  override def toString: String = value
  override def nameOf(key: String): Json = JsonNull
}
case class JsonArray(values: Seq[Json]) extends Json {
  override def toString: String = s"""[${values.mkString(",")}]"""
  override def nameOf(key: String): Json = JsonNull
}
case class JsonObject(values: Map[JsonString, Json]) extends Json {
  override def toString: String = s"""{${
    values.map { case (key, value) =>
      s"$key:$value"
    }.mkString(",")
  }}"""

  override def nameOf(key: String): Json = values(JsonString(key))
}
case object JsonNull extends Json {
  override def nameOf(key: String): Json = this
}

Jsonの型としては一旦こんなところでしょう。エスケープ処理は面倒なので考慮していません。 性能面では話にならないのであまり参考にしないでください。String to Jsonの処理もちゃんとやると結構面倒なのでイメージです。

さて、ここからいよいよマクロを書いていきますが、Scala3のマクロには2種類あります。

一つはお馴染み、AST(抽象構文木: Abstract Syntax Tree) を直接的に操作する方法です。 もう一つはScalaのコードがCompileされる前段のフェースでPrecompileされる、半マクロ的なものです。これは、scala3-libraryの scala.compiletime パッケージにAPIがあります。

Compiletime API

後者は近しいものがScala2にもありましたが、非常に強化されました。非常に簡単かつ、安全に使用できるAPIなので、出番も多いかもしれません。これはCompiletime APIと呼ばれています。 まずはこれを用いて、CaseClassのJson変換器を導出していきます。

// runtimeModule
object CodecGenerator {
  inline final def CaseClass[T]: Both[T] = InferCodecs.gen[T]
}

実行module側からinline functionを呼び出しています。 InferCodecs.gen は単なるinline関数であり、マクロ展開はされませんが、前述の通りPrecompileによってinline化されます。

// macroModule
object InferCodecs {
  inline def gen[A]: Both[A] = {
    summonFrom[Both[A]] {
      case given Both[A] => implicitly[Both[A]]
      case _: Mirror.ProductOf[A] => InferCodecs.derivedCodec[A]
      case _ => error("Cannot inferred.")
    }
  }

  inline def derivedCodec[A](using inline A: Mirror.ProductOf[A]): Both[A] =
    new Both[A] {
      override def write(value: A): Json = Writes.inferWrite[A].write(value)
      override def read(value: Json): A = Reads.inferRead[A].read(value)
    }

  trait ProductProjection {

    transparent inline def inferLabels[T <: Tuple]: List[String] = foldElementLabels[T]

    transparent inline def foldElementLabels[T <: Tuple]: List[String] =
      inline erasedValue[T] match {
        case _: EmptyTuple =>
          Nil
        case _: (t *: ts) =>
          constValue[t].asInstanceOf[String] :: foldElementLabels[ts]
      }
  }

  object Reads extends ProductProjection {

    inline def inferRead[A]: Read[A] = {
      summonFrom[Read[A]] {
        case x: Read[A] => x
        case _: Mirror.ProductOf[A] => Reads.derivedRead[A]
        case _ => error("Cannot inferred")
      }
    }

    transparent inline def derivedRead[A](using A: Mirror.ProductOf[A]): Read[A] =
      new Read[A] {
        private[this] val elemLabels = inferLabels[A.MirroredElemLabels]

        private[this] val elemReads: List[Read[_]] =
          inferReads[A.MirroredElemTypes]

        private[this] val elemSignature = elemLabels.zip(elemReads).zipWithIndex

        private[this] val elemCount = elemSignature.size

        override def read(value: Json): A = {
          val buffer = new Array[Any](elemCount)
          elemSignature.foreach { case ((label, read), i) =>
            buffer(i) = {
              read.read(value.nameOf(label))
            }
          }
          A.fromProduct(
            new Product {
              override def canEqual(that: Any): Boolean = true

              override def productArity: Int = elemCount

              override def productElement(n: Int): Any =
                buffer(n)
            }
          )
        }
      }

    private inline def inferReads[T <: Tuple]: List[Read[_]] = foldReads[T]

    private inline def foldReads[T <: Tuple]: List[Read[_]] =
      inline erasedValue[T] match {
        case _: EmptyTuple =>
          Nil
        case _: (t *: ts) =>
          inferRead[t] :: foldReads[ts]
      }
  }
  object Writes extends ProductProjection {
    inline def inferWrite[A]: Write[A] = {
      summonFrom[Write[A]] {
        case x: Write[A] => x
        case _: Mirror.ProductOf[A] => Writes.derivedWrite[A]
        case _ => error("Cannot inferred")
      }
    }

    inline def derivedWrite[A](using A: Mirror.ProductOf[A]): Write[A] =
      new Write[A] {
        private[this] val elemLabels = inferLabels[A.MirroredElemLabels]

        private[this] val elemDecoders: List[Write[_]] =
          inferWrites[A.MirroredElemTypes]

        private[this] val elemSignature = elemLabels.zip(elemDecoders).zipWithIndex

        private[this] val elemCount = elemSignature.size

        override def write(t: A): Json = {
          val entries = t.asInstanceOf[Product].productIterator.toArray
          JsonObject(
            (0 until elemCount).map { i =>
              JsonString(elemLabels(i)) -> elemDecoders(i).asInstanceOf[Write[Any]].write(entries(i))
            }.toMap
          )
        }
      }

    private inline def inferWrites[T <: Tuple]: List[Write[_]] = foldWrites[T]

    private inline def foldWrites[T <: Tuple]: List[Write[_]] =
      inline erasedValue[T] match {
        case _: EmptyTuple =>
          Nil
        case _: (t *: ts) =>
          inferWrite[t] :: foldWrites[ts]
      }
  }
}

複雑に見えますが、順に見ていきます。

InferCodecs.gen
  inline def gen[A]: Both[A] = {
    summonFrom[Both[A]] {
      case given Both[A] => implicitly[Both[A]]
      case _: Mirror.ProductOf[A] => InferCodecs.derivedCodec[A]
      case _ => error("Cannot inferred.")
    }
  }

summonFrom というのはcompiletime APIで、scala2における implicitly[T] に近いものです。 パターンマッチによって、 given Both[A] がimplicit scopeに見つかればそれを返して終わります。

case _: Mirror.ProductOf[A] もpackageこそ変わっていますが馴染み深いのではないでしょうか?要するにケースクラスであり、メンバの型情報が導出できれば、という分岐で、 given Both[A] は見つからないけど、型情報導出できそうなのでReadとWriteをそれぞれ導出するぜ!という処理です。

Reads.inferReads
    inline def inferRead[A]: Read[A] = {
      summonFrom[Read[A]] {
        case x: Read[A] => x
        case _: Mirror.ProductOf[A] => Reads.derivedRead[A]
        case _ => error("Cannot inferred")
      }
    }

これも一緒ですね、さっきのは Both[T] の導出だったのに対し、今回は Read[T] の導出になっているだけです。

Reads.derivedRead
    transparent inline def derivedRead[A](using A: Mirror.ProductOf[A]): Read[A] =
      new Read[A] {
        private[this] val elemLabels = inferLabels[A.MirroredElemLabels]

        private[this] val elemReads: List[Read[_]] =
          inferReads[A.MirroredElemTypes]

        private[this] val elemSignature = elemLabels.zip(elemReads).zipWithIndex

        private[this] val elemCount = elemSignature.size

        override def read(value: Json): A = {
          val buffer = new Array[Any](elemCount)
          elemSignature.foreach { case ((label, read), i) =>
            buffer(i) = {
              read.read(value.nameOf(label))
            }
          }
          A.fromProduct(
            new Product {
              override def canEqual(that: Any): Boolean = true

              override def productArity: Int = elemCount

              override def productElement(n: Int): Any =
                buffer(n)
            }
          )
        }
      }    

    private inline def inferReads[T <: Tuple]: List[Read[_]] = foldReads[T]

    private inline def foldReads[T <: Tuple]: List[Read[_]] =
      inline erasedValue[T] match {
        case _: EmptyTuple =>
          Nil
        case _: (t *: ts) =>
          inferRead[t] :: foldReads[ts]
      }

さて、いよいよcompiletime APIの本領発揮です。

inferReads[T <: Tuple]: List[Read[_]] = foldReads[T]

これはシグネチャから想像できる通り、任意のTupleのサブタイプ T の要素から Read[T] を導出してListで返しています。

inline erasedValue[T] これはまぁ見ての通りですが、 T <: Tupletypeunapply 的な操作である事は想像できます。 ※リファレンス

EmptyTuple か、 _: (t *: ts) かのパターンがあるという事は、 (A, B, C)(A *: (B, C))unapply されるという事ですね。そして A, B, C が順次導出されていきます。WritesもReadと同じですね。

        private[this] val elemLabels = inferLabels[A.MirroredElemLabels]

        private[this] val elemReads: List[Read[_]] =
          inferReads[A.MirroredElemTypes]

        private[this] val elemSignature = elemLabels.zip(elemReads).zipWithIndex

        private[this] val elemCount = elemSignature.size

この辺はなんとなく想像しやすいと思います。型情報からラベル (変数名) とそれに対応する再起的に導出された Read[_]を持ち、 read(value: Json) が呼び出された際にはマッピングして各メンバをdeserializeしますよ、という事ですね。

さて、これでケースクラスのCodecは導出できるようになりました。後は末端のプリミティブ型のCodecが事前に定義されていて、implicit scopeに見つかれば動きそうです。 サンプルとして、末端がStringかIntであればいいとしましょう。

object Codecs {
  given Both[Int] with {
    override def read(json: Json): Int = json.toString.toInt
    override def write(t: Int): Json = JsonAny(t.toString)
  }

  given Both[String] with {
    override def read(json: Json): String = json.toString

    override def write(t: String): Json = JsonString(t)
  }
}

object Main extends App {
  export Codecs.given

  case class Root(child: Child)
  case class Child(str: String, num: Int)

  val codec = InferCodecs.gen[Root]
  val raw = codec.write(
      Root(
        Child("A", 1),
      )
  )
  println(raw)

  val value = codec.read(raw)
  println(value)

}

出力結果はこんな感じになりました。

{"child":{"str":"A","num":1}}
Root(Child("A",1))

雑な作りではありますが、雰囲気はつかめたのではないでしょうか? このサンプルでは、 given Both[T]Mirror.ProductOf[A] のいずれかしか処理していないので、間に Seq[T] などが入ってきてしまうとcompileできなくなります。 そして、マクロの呼び出し元はcompileの時点で異常に気付くことができるようになるわけです。

※本内容には、一部個人的な解釈を含んでいます。

最後に

面白い反面、マクロの運用にはなかなかリスクも伴います。 Scala3はまだまだガンガン開発が進んでいるので、3.0 ~ 3.3にかけて結構APIが変わっていたり、それだけならまだしもシグネチャのシンボルがこっそり変わったりして、OSSによってはScala3.xごとにsrcクラスパスを分けていたりします。 すでに、cross buildという観点では、同一のsrcクラスパスで各マイナーバージョンでのマクロmoduleのビルドはなかなか難しいかもしれません。

Scala2と比較して、構文literal直書き (scala.reflect.api.Quasiquotes) ができなくなった分マシにはなりましたが、結構めちゃくちゃなことをやっているところも少なくないので、適当なOSSのマクロ関数を読む事はあまりお勧めできません。 何より一番厳しいのは情報の少なさで、実装を読むか、公式 に聞く以外お手上げなんてことも... そんな中でも少しでも情報を集めている方は、私がメンテナンスしている DI framework のコードでも参考にして見てください。少しは助けになるかもしれません。

次回はもっと自由度の高いMacro APIについてお話ししようと思います。 Techオタクなあなた、是非私たちとお話しましょう!

jobs.henry-app.jp