Skip to content

Commit

Permalink
all changes necessary for nice stackrefs
Browse files Browse the repository at this point in the history
  • Loading branch information
lbialy committed Aug 26, 2024
1 parent 23fbaa1 commit 2515c39
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 80 deletions.
2 changes: 1 addition & 1 deletion besom-json/src/main/scala/besom/json/JsonFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ object JsonReader {
def read(json: JsValue) = f(json)
}

inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonFormatN[T]
inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonReaderN[T]
}

/** Provides the JSON serialization for type T.
Expand Down
89 changes: 72 additions & 17 deletions besom-json/src/main/scala/besom/json/ProductFormats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ trait ProductFormats:
def requireNullsForOptions: Boolean = false

inline def jsonFormatN[T <: Product]: RootJsonFormat[T] = ${ ProductFormatsMacro.jsonFormatImpl[T]('self) }
inline def jsonReaderN[T <: Product]: RootJsonReader[T] = ${ ProductFormatsMacro.jsonReaderImpl[T]('self) }

object ProductFormatsMacro:
import scala.deriving.*
Expand Down Expand Up @@ -57,29 +58,45 @@ object ProductFormatsMacro:
'{ $namesExpr.zip($identsExpr).toMap }
catch case cce: ClassCastException => '{ Map.empty[String, Any] } // TODO drop after https://github.com/lampepfl/dotty/issues/19732

private def prepareFormatInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareFormatInstances(Type.of[labelsTail], Type.of[tpesTail])

private def prepareReaderInstances(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[Expr[(String, JsonReader[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonReader[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareReaderInstances(Type.of[labelsTail], Type.of[tpesTail])

def jsonFormatImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonFormat[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
def prepareInstances(elemLabels: Type[?], elemTypes: Type[?]): List[Expr[(String, JsonFormat[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[JsonFormat[tpe]].getOrElse {
quotes.reflect.report.errorAndAbort("Missing given instance of JsonFormat[" ++ Type.show[tpe] ++ "]")
} // TODO: Handle missing instance
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail])

// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val allInstancesExpr = Expr.ofList(prepareFormatInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
Expand Down Expand Up @@ -121,6 +138,44 @@ object ProductFormatsMacro:

JsObject(fields.toMap)
}

def jsonReaderImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonReader[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
$m: Mirror.ProductOf[T] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes }
} =>
// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareReaderInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
new RootJsonReader[T]:
private val allInstances = ${ allInstancesExpr }
private val fmts = ${ prodFormats }
private val defaultArgs = ${ defaultArguments }

def read(json: JsValue): T = json match
case JsObject(fields) =>
val values = allInstances.map { case (fieldName, fieldFormat, isOption) =>
try fieldFormat.read(fields(fieldName))
catch
case e: NoSuchElementException =>
// if field has a default value, use it, we didn't find anything in the JSON
if defaultArgs.contains(fieldName) then defaultArgs(fieldName)
// if field is optional and requireNullsForOptions is disabled, return None
// otherwise we require an explicit null value
else if isOption && !fmts.requireNullsForOptions then None
// it's missing so we throw an exception
else throw DeserializationException("Object is missing required member '" ++ fieldName ++ "'", null, fieldName :: Nil)
case DeserializationException(msg, cause, fieldNames) =>
throw DeserializationException(msg, cause, fieldName :: fieldNames)
}
$m.fromProduct(Tuple.fromArray(values.toArray))

case _ => throw DeserializationException("Object expected", null, allInstances.map(_._1))

}

end ProductFormatsMacro

/** This trait supplies an alternative rendering mode for optional case class members. Normally optional members that are undefined (`None`)
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/besom/aliases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ object aliases:
object CustomTimeouts extends besom.internal.CustomTimeoutsFactory

export besom.internal.InvokeOptions
export besom.util.JsonReaderInstances.*
end aliases
39 changes: 9 additions & 30 deletions core/src/main/scala/besom/internal/StackReference.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ end StackReference
trait StackReferenceFactory:
sealed trait StackReferenceType[T]:
type Out[T]
def transform(stackReference: StackReference): Output[Out[T]]
def transform(stackReference: StackReference)(using Context): Output[Out[T]]

object StackReferenceType:
given untyped: UntypedStackReferenceType = UntypedStackReferenceType()

given typed[T: JsonReader]: TypedStackReferenceType[T] = TypedStackReferenceType[T]

class TypedStackReferenceType[T](using JsonReader[T]) extends StackReferenceType[T]:
type Out[T] = TypedStackReference[T]
def transform(stackReference: StackReference): Output[Out[T]] =
def transform(stackReference: StackReference)(using Context): Output[Out[T]] =
val objectOutput: Output[T] =
requireObject(stackReference.outputs, stackReference.secretOutputNames)

Expand All @@ -71,16 +76,9 @@ trait StackReferenceFactory:
)
)

class UntypedStackReferenceType(using Context) extends StackReferenceType[Any]:
class UntypedStackReferenceType extends StackReferenceType[Any]:
type Out[T] = StackReference
def transform(stackReference: StackReference): Output[StackReference] = Output(stackReference)

import scala.compiletime.summonFrom
inline implicit def stackRefTypeProvider[T](using Context): StackReferenceType[T] =
summonFrom {
case _: besom.json.JsonReader[T] => typedStackReference[T]
case _ => untypedStackReference.asInstanceOf[StackReferenceType[T]]
}
def transform(stackReference: StackReference)(using Context): Output[StackReference] = Output(stackReference)

def untypedStackReference(using Context): StackReferenceType[Any] = UntypedStackReferenceType()

Expand Down Expand Up @@ -113,25 +111,6 @@ trait StackReferenceFactory:
}
.flatMap(stackRefType.transform)

// def apply[T](using
// ctx: Context,
// jr: JsonReader[T]
// )(name: NonEmptyString, args: Input.Optional[StackReferenceArgs], opts: StackReferenceResourceOptions): Output[TypedStackReference[T]] =
// apply(using ctx)(name, args, opts).flatMap { stackReference =>
// val objectOutput: Output[T] =
// requireObject(stackReference.outputs, stackReference.secretOutputNames)

// objectOutput.map(t =>
// TypedStackReference(
// urn = stackReference.urn,
// id = stackReference.id,
// name = stackReference.name,
// outputs = t,
// secretOutputNames = stackReference.secretOutputNames
// )
// )
// }

private[internal] def requireObject[T: JsonReader](
outputs: Output[Map[String, JsValue]],
secretOutputNames: Output[Set[String]]
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/besom/internal/codecs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ object Decoder extends DecoderInstancesLowPrio1:
.foldLeft[ValidatedResult[DecodingError, Vector[OutputData[A]]]](ValidatedResult.valid(Vector.empty))(
accumulatedOutputDataOrErrors(_, _, "iterable", label)
)
.map(_.toIterable)
.map(_.toVector)
.map(OutputData.sequence)
end if
}
Expand Down
33 changes: 33 additions & 0 deletions core/src/main/scala/besom/util/JsonReaderInstances.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package besom.util

import besom.json.*
import besom.internal.{Output, Context}
import besom.internal.Constants, Constants.SpecialSig

object JsonReaderInstances:
implicit def outputJsonReader[A](using jsonReader: JsonReader[A], ctx: Context): JsonReader[Output[A]] =
new JsonReader[Output[A]]:
def read(json: JsValue): Output[A] = json match
case JsObject(fields) =>
fields.get(SpecialSig.Key) match
case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.OutputSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case Some(JsString(sig)) if SpecialSig.fromString(sig) == Some(SpecialSig.SecretSig) =>
val maybeInnerValue = fields.get(Constants.ValueName)
maybeInnerValue
.map { innerValue =>
try Output.secret(jsonReader.read(innerValue))
catch case e: Throwable => Output.fail(e)
}
.getOrElse(Output.fail(Exception("Invalid JSON")))

case _ => Output.fail(Exception("Invalid JSON"))

case _ => Output.fail(Exception("Invalid JSON"))
17 changes: 0 additions & 17 deletions core/src/test/scala/besom/experimental/StackRefCompilation.scala

This file was deleted.

Binary file removed cs
Binary file not shown.
12 changes: 6 additions & 6 deletions integration-tests/CoreTests.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down Expand Up @@ -182,19 +182,19 @@ class CoreTests extends munit.FunSuite {
case pulumi.FixtureMultiContext(ctx, Vector(ctx1, ctx2)) =>
println(s"Source stack name: ${ctx1.stackName}, pulumi home: ${ctx.home}")
pulumi.up(ctx1.stackName).call(cwd = ctx1.programDir, env = ctx1.env)
val outputs1 = upickle.default.read[Map[String, ujson.Value]](
val expected = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx1.stackName, "--show-secrets").call(cwd = ctx1.programDir, env = ctx1.env).out.text()
)

println(s"Target stack name: ${ctx2.stackName}, pulumi home: ${ctx.home}")
pulumi
.up(ctx2.stackName, "--config", s"sourceStack=organization/source-stack-test/${ctx1.stackName}")
.call(cwd = ctx2.programDir, env = ctx2.env)
val outputs2 = upickle.default.read[Map[String, ujson.Value]](
val obtained = upickle.default.read[Map[String, ujson.Value]](
pulumi.outputs(ctx2.stackName, "--show-secrets").call(cwd = ctx2.programDir, env = ctx2.env).out.text()
)

assertEquals(outputs1, outputs2)
assertEquals(obtained, expected)

case _ => throw new Exception("Invalid number of contexts")
}
Expand Down
60 changes: 52 additions & 8 deletions integration-tests/resources/references/target-stack/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import besom.json.*

//noinspection UnitMethodIsParameterless,TypeAnnotation
@main def main = Pulumi.run {

case class Structured(a: Output[String], b: Double) derives JsonReader, Encoder
case class SourceStack(sshKeyUrn: String, value1: Int, value2: String, structured: Structured) derives JsonReader, Encoder

val sourceStackName = config.requireString("sourceStack").map(NonEmptyString(_).get)
val sourceStack = besom.StackReference(
val sourceStack = StackReference(
"stackRef",
StackReferenceArgs(sourceStackName),
StackReferenceResourceOptions()
Expand All @@ -22,14 +26,29 @@ import besom.json.*
)

val value1 = sourceStack.flatMap(_.getOutput("value1").map {
case Some(JsNumber(s)) => s.toInt
case other => throw RuntimeException(s"Expected string, got $other")
case Some(JsNumber(s)) =>
val i = s.toInt
assert(i == 23, "value1 should be 23")
i
case other => throw RuntimeException(s"Expected string, got $other")
})

val value2 = sourceStack.flatMap(_.getOutput("value2").map {
case Some(JsString(s)) => s
case other => throw RuntimeException(s"Expected string, got $other")
case Some(JsString(s)) =>
assert(s == "Hello world!", "value2 should be Hello world!")
s
case other => throw RuntimeException(s"Expected string, got $other")
})
val structured = sourceStack.flatMap(_.getOutput("structured"))

val structured = sourceStack.flatMap(_.getOutput("structured")).map {
case Some(js @ JsObject(map)) =>
assert(map.size == 2, "structured should have 2 fields")
assert(map.get("a").flatMap(_.asJsObject.fields.get("value")).contains(JsString("ABCDEF")), "structured.a should be ABCDEF")
assert(map.get("b").map(_.toString.toDouble).contains(42.0), "structured.b should be 42.0")
js.asInstanceOf[JsValue]
case Some(_) => throw RuntimeException("structured should be a JsObject")
case None => throw RuntimeException("structured should be a JsObject")
}

val sanityCheck = Output {
for
Expand All @@ -44,10 +63,35 @@ import besom.json.*
assert(s, "structured should be a secret")
}

Stack(Output(sanityCheck)).exports(
val typedSourceStack = StackReference[SourceStack](
"stackRef",
StackReferenceArgs(sourceStackName),
StackReferenceResourceOptions()
)

val typedSanityCheck = typedSourceStack.flatMap { sourceStack =>
val outputs = sourceStack.outputs
outputs.structured.a.flatMap { a =>
assert(a == "ABCDEF", "structured.a should be ABCDEF")

Output {

assert(
outputs.sshKeyUrn.startsWith("urn:pulumi:tests-stack-outputs-and-references-should-work") &&
outputs.sshKeyUrn.endsWith("::source-stack-test::tls:index/privateKey:PrivateKey::sshKey")
)
assert(outputs.value1 == 23, "value1 should be 23")
assert(outputs.value2 == "Hello world!", "value2 should be Hello world!")
assert(outputs.structured.b == 42.0, "structured.b should be 42.0")
outputs
}
}
}

Stack(typedSanityCheck, sanityCheck).exports(
sshKeyUrn = sshKeyUrn,
value1 = value1,
value2 = value2,
structured = structured
structured = typedSourceStack.map(_.outputs.structured)
)
}

0 comments on commit 2515c39

Please sign in to comment.