Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions modules/cli/src/main/scala/dev/guardrail/cli/CLICommon.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scala.language.reflectiveCalls

import dev.guardrail._
import dev.guardrail.core.{ LogLevel, LogLevels }
import dev.guardrail.generators.ScalaVersion
import dev.guardrail.terms.protocol.PropertyRequirement
import dev.guardrail.runner.GuardrailRunner

Expand Down Expand Up @@ -50,6 +51,7 @@ trait CLICommon extends GuardrailRunner {
| --module <module name> : Explicitly select libraries to satisfy composition requirements
| --custom-extraction : Permit supplying an akka-http Directive into the generated guardrail routing layer (server only)
| --package-from-tags : Use the tags, defined in the OpenAPI specification, to guide the generated package structures
| --scala-version <version> : Target Scala version for generated code (2.12, 2.13, 3). Default: 2.13
|
|Examples:
| Generate two clients, put both in src/main/scala, under different packages, one with tracing, one without:
Expand Down Expand Up @@ -89,6 +91,12 @@ trait CLICommon extends GuardrailRunner {
case _ => Target.raiseError(UnparseableArgument(s"${arg} ${value}", "Expected one of 'disable', 'native', 'simple' or 'custom'"))
}

def parseScalaVersion(arg: String, value: String): Target[ScalaVersion] =
ScalaVersion.fromString(value) match {
case Right(v) => Target.pure(v)
case Left(err) => Target.raiseError(UnparseableArgument(arg, err))
}

def parseArgs(args: Array[String]): Target[List[Args]] = {
def expandTilde(path: String): String =
path.replaceFirst("^~", System.getProperty("user.home"))
Expand Down Expand Up @@ -160,6 +168,11 @@ trait CLICommon extends GuardrailRunner {
auth <- parseAuthImplementation(arg, value)
res <- Continue((sofar.modifyContext(_.withAuthImplementation(auth)) :: already, xs))
} yield res
case (sofar :: already, (arg @ "--scala-version") :: value :: xs) =>
for {
scalaVer <- parseScalaVersion(arg, value)
res <- Continue((sofar.modifyContext(_.withScalaVersion(scalaVer)) :: already, xs))
} yield res
case (_, unknown) =>
debug("Unknown argument") >> Bail(UnknownArguments(unknown))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package dev.guardrail.generators

sealed abstract class ScalaVersion(val value: String) extends Product with Serializable {
def isScala3: Boolean = this == ScalaVersion.Scala3
def isScala2: Boolean = !isScala3
}

object ScalaVersion {
case object Scala212 extends ScalaVersion("2.12")
case object Scala213 extends ScalaVersion("2.13")
case object Scala3 extends ScalaVersion("3")

def fromString(version: String): Either[String, ScalaVersion] = version.trim.toLowerCase match {
case "2.12" | "2.12.x" => Right(Scala212)
case "2.13" | "2.13.x" => Right(Scala213)
case "3" | "3.x" => Right(Scala3)
case s if s.startsWith("3.") => Right(Scala3)
case other => Left(s"Unsupported Scala version: $other. Supported: 2.12, 2.13, 3")
}

val default: ScalaVersion = Scala213
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package dev.guardrail.generators

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class ScalaVersionSpec extends AnyFunSuite with Matchers {

test("ScalaVersion.fromString should parse 2.12") {
ScalaVersion.fromString("2.12") shouldBe Right(ScalaVersion.Scala212)
}

test("ScalaVersion.fromString should parse 2.12.x") {
ScalaVersion.fromString("2.12.x") shouldBe Right(ScalaVersion.Scala212)
}

test("ScalaVersion.fromString should parse 2.13") {
ScalaVersion.fromString("2.13") shouldBe Right(ScalaVersion.Scala213)
}

test("ScalaVersion.fromString should parse 2.13.x") {
ScalaVersion.fromString("2.13.x") shouldBe Right(ScalaVersion.Scala213)
}

test("ScalaVersion.fromString should parse 3") {
ScalaVersion.fromString("3") shouldBe Right(ScalaVersion.Scala3)
}

test("ScalaVersion.fromString should parse 3.x") {
ScalaVersion.fromString("3.x") shouldBe Right(ScalaVersion.Scala3)
}

test("ScalaVersion.fromString should parse 3.3.4") {
ScalaVersion.fromString("3.3.4") shouldBe Right(ScalaVersion.Scala3)
}

test("ScalaVersion.fromString should parse 3.4.0") {
ScalaVersion.fromString("3.4.0") shouldBe Right(ScalaVersion.Scala3)
}

test("ScalaVersion.fromString should handle whitespace") {
ScalaVersion.fromString(" 3 ") shouldBe Right(ScalaVersion.Scala3)
ScalaVersion.fromString(" 2.13 ") shouldBe Right(ScalaVersion.Scala213)
}

test("ScalaVersion.fromString should be case insensitive") {
ScalaVersion.fromString("3.X") shouldBe Right(ScalaVersion.Scala3)
ScalaVersion.fromString("2.13.X") shouldBe Right(ScalaVersion.Scala213)
}

test("ScalaVersion.fromString should return error for unsupported versions") {
ScalaVersion.fromString("2.11") shouldBe a[Left[_, _]]
ScalaVersion.fromString("invalid") shouldBe a[Left[_, _]]
ScalaVersion.fromString("") shouldBe a[Left[_, _]]
}

test("ScalaVersion.fromString error message should be informative") {
val result = ScalaVersion.fromString("2.11")
result match {
case Left(msg) =>
msg should include("Unsupported Scala version")
msg should include("2.11")
case Right(_) => fail("Expected Left")
}
}

test("ScalaVersion.default should be Scala213") {
ScalaVersion.default shouldBe ScalaVersion.Scala213
}

test("isScala3 should return true for Scala3") {
ScalaVersion.Scala3.isScala3 shouldBe true
ScalaVersion.Scala213.isScala3 shouldBe false
ScalaVersion.Scala212.isScala3 shouldBe false
}

test("isScala2 should return true for Scala 2 versions") {
ScalaVersion.Scala3.isScala2 shouldBe false
ScalaVersion.Scala213.isScala2 shouldBe true
ScalaVersion.Scala212.isScala2 shouldBe true
}

test("ScalaVersion.value should return correct string") {
ScalaVersion.Scala212.value shouldBe "2.12"
ScalaVersion.Scala213.value shouldBe "2.13"
ScalaVersion.Scala3.value shouldBe "3"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import dev.guardrail.generators.scala.CirceRefinedModelGenerator
import dev.guardrail.generators.scala.JacksonModelGenerator
import dev.guardrail.generators.scala.ModelGeneratorType
import dev.guardrail.generators.scala.ResponseADTHelper
import dev.guardrail.generators.scala.Scala3Compat
import dev.guardrail.generators.scala.ScalaLanguage
import dev.guardrail.generators.scala.syntax._
import dev.guardrail.generators.ScalaVersion
import dev.guardrail.generators.spi.ClientGeneratorLoader
import dev.guardrail.generators.spi.ModuleLoadResult
import dev.guardrail.generators.spi.ProtocolGeneratorLoader
Expand Down Expand Up @@ -108,8 +110,8 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e
}
(responseDefinitions, clientOperations) = responseClientPair.unzip
tracingName = Option(className.mkString("-")).filterNot(_.isEmpty)
ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing)
staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing)
ctorArgs <- clientClsArgs(tracingName, serverUrls, context.tracing, context.scalaVersion)
staticDefns <- buildStaticDefns(clientName, tracingName, serverUrls, ctorArgs, context.tracing, context.scalaVersion)
client <- buildClient(
clientName,
tracingName,
Expand Down Expand Up @@ -492,7 +494,8 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e
private def clientClsArgs(
tracingName: Option[String],
serverUrls: Option[NonEmptyList[URI]],
tracing: Boolean
tracing: Boolean,
scalaVersion: ScalaVersion
): Target[List[Term.ParamClause]] = {
val implicits = List(
param"httpClient: HttpRequest => Future[HttpResponse]",
Expand All @@ -508,7 +511,7 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e
else None),
None
),
Term.ParamClause(implicits ++ protocolImplicits, Some(Mod.Implicit()))
Scala3Compat.implicitsClause(implicits ++ protocolImplicits, scalaVersion)
)
}
private def generateResponseDefinitions(
Expand All @@ -527,7 +530,8 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e
tracingName: Option[String],
serverUrls: Option[NonEmptyList[URI]],
ctorArgs: List[Term.ParamClause],
tracing: Boolean
tracing: Boolean,
scalaVersion: ScalaVersion
): Target[StaticDefns[ScalaLanguage]] = {
def extraConstructors(
tracingName: Option[String],
Expand All @@ -553,7 +557,7 @@ class AkkaHttpClientGenerator private (modelGeneratorType: ModelGeneratorType) e
List(param"httpClient: HttpRequest => Future[HttpResponse]", formatHost(serverUrls)) ++ tracingParams,
None
),
Term.ParamClause(implicits ++ protocolImplicits, Some(Mod.Implicit()))
Scala3Compat.implicitsClause(implicits ++ protocolImplicits, scalaVersion)
)
} yield List(
q"""def httpClient(...${args}): $tpe = $ctorCall"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ object AkkaHttpHelper {
case _ => Left(s"Unknown modelGeneratorType: ${modelGeneratorType}")
}

def protocolImplicits(modelGeneratorType: ModelGeneratorType): Target[Term.ParamClause] = modelGeneratorType match {
case _: CirceModelGenerator => Target.pure(Term.ParamClause(Nil))
def protocolImplicits(modelGeneratorType: ModelGeneratorType): Target[List[Term.Param]] = modelGeneratorType match {
case _: CirceModelGenerator => Target.pure(Nil)
case _: JacksonModelGenerator =>
Target.pure(
Term.ParamClause(
List(
param"mapper: com.fasterxml.jackson.databind.ObjectMapper",
param"validator: javax.validation.Validator"
),
Some(Mod.Implicit())
List(
param"mapper: com.fasterxml.jackson.databind.ObjectMapper",
param"validator: javax.validation.Validator"
)
)
case _ => Target.raiseError(RuntimeFailure(s"Unknown modelGeneratorType: ${modelGeneratorType}"))
Expand Down
Loading
Loading