Refactor BQ to expose all of Beam's configurations #5456

Draft · wants to merge 3 commits into main · changes from all commits
@@ -154,7 +154,7 @@ class BigQueryClientIT extends AnyFlatSpec with Matchers {

"TableService.getRows" should "work" in {
val rows =
-      bq.tables.rows(Table.Spec("bigquery-public-data:samples.shakespeare")).take(10).toList
+      bq.tables.rows(Table("bigquery-public-data:samples.shakespeare")).take(10).toList
val columns = Set("word", "word_count", "corpus", "corpus_date")
all(rows.map(_.keySet().asScala)) shouldBe columns
}
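A change that recurs throughout this PR: the separate `Table.Spec` and `Table.Ref` constructors are collapsed into a single `Table` apply. A minimal sketch of the unified entry point, inferred from the diffs below (the exact overload set is an assumption):

    import com.google.api.services.bigquery.model.TableReference
    import com.spotify.scio.bigquery.Table

    // From a "project:dataset.table" spec string (previously Table.Spec)
    val bySpec = Table("bigquery-public-data:samples.shakespeare")

    // From a TableReference (previously Table.Ref)
    val ref = new TableReference()
      .setProjectId("bigquery-public-data")
      .setDatasetId("samples")
      .setTableId("shakespeare")
    val byRef = Table(ref)

    // With a Storage API row restriction, as used in BigQueryStorageIT further down
    val filtered = Table("bigquery-public-data:samples.shakespeare", "word_count > 10")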
@@ -45,7 +45,7 @@ class BigQueryIOIT extends PipelineSpec {

"Select" should "read typed values from a SQL query" in
runWithRealContext(options) { sc =>
-      val scoll = sc.read(BigQueryTyped[ShakespeareFromQuery])
+      val scoll = sc.typedBigQueryStorage[ShakespeareFromQuery]()
scoll should haveSize(10)
scoll should satisfy[ShakespeareFromQuery] {
_.forall(_.getClass == classOf[ShakespeareFromQuery])
@@ -54,7 +54,7 @@ class BigQueryIOIT extends PipelineSpec {

"TableRef" should "read typed values from table" in
runWithRealContext(options) { sc =>
-      val scoll = sc.read(BigQueryTyped[ShakespeareFromTable])
+      val scoll = sc.typedBigQueryStorage[ShakespeareFromTable]()
scoll.take(10) should haveSize(10)
scoll should satisfy[ShakespeareFromTable] {
_.forall(_.getClass == classOf[ShakespeareFromTable])
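Both tests move from the generic `sc.read(BigQueryTyped[T])` entry point to the dedicated `typedBigQueryStorage` helper. A sketch of the two call shapes this PR uses, reusing this file's annotated types:

    // Source taken from the @BigQueryType annotation on the type
    val fromAnnotation = sc.typedBigQueryStorage[ShakespeareFromTable]()

    // Or with an explicit source, as in BigQueryStorageIT further down
    val fromExplicit =
      sc.typedBigQueryStorage[ShakespeareFromTable](Table("bigquery-public-data:samples.shakespeare"))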
@@ -20,7 +20,6 @@ package com.spotify.scio.bigquery
import com.google.protobuf.ByteString
import com.spotify.scio._
import com.spotify.scio.avro._
-import com.spotify.scio.bigquery.BigQueryTypedTable.Format
import com.spotify.scio.bigquery.client.BigQuery
import com.spotify.scio.testing._
import magnolify.scalacheck.auto._
@@ -69,7 +68,7 @@ object TypedBigQueryIT {
val now = Instant.now().toString(TIME_FORMATTER)
val spec =
s"data-integration-test:bigquery_avro_it.$name${now}_${Random.nextInt(Int.MaxValue)}"
-    Table.Spec(spec)
+    Table(spec)
}
private val tableRowTable = table("records_tablerow")
private val avroTable = table("records_avro")
@@ -101,37 +100,25 @@ class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {
BigQuery.defaultInstance().tables.delete(avroLogicalTypeTable.ref)
}

"TypedBigQuery" should "read records" in {
"typedBigQuery" should "read records" in {
val sc = ScioContext(options)
sc.typedBigQuery[Record](tableRowTable) should containInAnyOrder(records)
sc.run()
}

it should "convert to avro format" in {
"bigQueryTableFormat" should "read TableRow records" in {
val sc = ScioContext(options)
-    implicit val coder = avroGenericRecordCoder(Record.avroSchema)
-    sc.typedBigQuery[Record](tableRowTable)
-      .map(Record.toAvro)
-      .map(Record.fromAvro) should containInAnyOrder(
-      records
-    )
+    val format = BigQueryIO.Format.Default(BigQueryType[Record])
+    val data = sc.bigQueryTableFormat(tableRowTable, format)
+    data should containInAnyOrder(records)
sc.run()
}

"BigQueryTypedTable" should "read TableRow records" in {
it should "read GenericRecord records" in {
val sc = ScioContext(options)
sc
.bigQueryTable(tableRowTable)
.map(Record.fromTableRow) should containInAnyOrder(records)
sc.run()
}

it should "read GenericRecord recors" in {
val sc = ScioContext(options)
implicit val coder = avroGenericRecordCoder(Record.avroSchema)
sc
.bigQueryTable(tableRowTable, Format.GenericRecord)
.map(Record.fromAvro) should containInAnyOrder(records)
val format = BigQueryIO.Format.Avro(BigQueryType[Record])
val data = sc.bigQueryTableFormat(tableRowTable, format)
data should containInAnyOrder(records)
sc.run()
}

@@ -157,7 +144,7 @@ class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {
|}
""".stripMargin)
val tap = sc
-      .bigQueryTable(tableRowTable, Format.GenericRecord)
+      .bigQueryTableFormat(tableRowTable, BigQueryIO.Format.Avro())
.saveAsBigQueryTable(avroTable, schema = schema, createDisposition = CREATE_IF_NEEDED)

val result = sc.run().waitUntilDone()
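The `Format.GenericRecord` flag on `bigQueryTable` is replaced by an explicit `BigQueryIO.Format` argument to the new `bigQueryTableFormat`. A sketch of the two variants exercised above, assuming `Format.Default` decodes through the `BigQueryType`-derived TableRow converter and `Format.Avro` through the Avro one (a bare `Format.Avro()` appears to fall back to GenericRecord, per the last hunk):

    val sc = ScioContext(options)

    // TableRow-backed read
    val tableRows =
      sc.bigQueryTableFormat(tableRowTable, BigQueryIO.Format.Default(BigQueryType[Record]))

    // Avro-backed read
    val avroRows =
      sc.bigQueryTableFormat(tableRowTable, BigQueryIO.Format.Avro(BigQueryType[Record]))

    sc.run()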
@@ -125,7 +125,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQuery[NestedWithFields]()
+      .typedBigQueryStorage[NestedWithFields]()
.map(r => (r.required.int, r.required.string, r.optional.get.int))
.internal
PAssert.that(p).containsInAnyOrder(expected)
@@ -139,7 +139,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQuery[NestedWithRestriction]()
+      .typedBigQueryStorage[NestedWithRestriction]()
.map { r =>
val (req, opt, rep) = (r.required, r.optional.get, r.repeated.head)
(req.int, req.string, opt.int, opt.string, rep.int, rep.string)
@@ -155,8 +155,10 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
val (sc, _) = ContextAndArgs(
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
+    val bqt = BigQueryType[NestedWithRestriction]
+    val source = Table(bqt.table.get, "required.int < 3")
     val p = sc
-      .typedBigQueryStorage[NestedWithRestriction](rowRestriction = "required.int < 3")
+      .typedBigQueryStorage[NestedWithRestriction](source)
.map { r =>
val (req, opt, rep) = (r.required, r.optional.get, r.repeated.head)
(req.int, req.string, opt.int, opt.string, rep.int, rep.string)
@@ -172,7 +174,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQuery[NestedWithAll](Table.Spec(NestedWithAll.table.format("nested")))
+      .typedBigQueryStorage[NestedWithAll](Table(NestedWithAll.table.format("nested")))
.map(r => (r.required.int, r.required.string, r.optional.get.int))
.internal
PAssert.that(p).containsInAnyOrder(expected)
@@ -232,7 +234,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
val (sc, _) = ContextAndArgs(
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
-    val p = sc.typedBigQuery[FromTable]().internal
+    val p = sc.typedBigQueryStorage[FromTable]().internal
PAssert.that(p).containsInAnyOrder(expected)
sc.run()
}
@@ -243,7 +245,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
val p = sc
-      .typedBigQueryStorage[ToTableRequired](Table.Spec("data-integration-test:storage.required"))
+      .typedBigQueryStorage[ToTableRequired](Table("data-integration-test:storage.required"))
.internal
PAssert.that(p).containsInAnyOrder(expected)
sc.run()
@@ -272,7 +274,7 @@ class BigQueryStorageIT extends AnyFlatSpec with Matchers {
val (sc, _) = ContextAndArgs(
Array("--project=data-integration-test", "--tempLocation=gs://data-integration-test-eu/temp")
)
-    val p = sc.typedBigQuery[FromQuery]().internal
+    val p = sc.typedBigQueryStorage[FromQuery]().internal
PAssert.that(p).containsInAnyOrder(expected)
sc.run()
}
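Also worth noting in this file: the `rowRestriction` parameter no longer lives on `typedBigQueryStorage`; the Storage API filter now travels with the `Table` source itself. A minimal sketch of the pattern used above:

    val bqt = BigQueryType[NestedWithRestriction]
    // The filter is attached to the source, not to the read call
    val source = Table(bqt.table.get, "required.int < 3")
    val filtered = sc.typedBigQueryStorage[NestedWithRestriction](source)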
@@ -187,10 +187,9 @@ class BigQueryTypeIT extends AnyFlatSpec with Matchers {
tableReference.setProjectId("data-integration-test")
tableReference.setDatasetId("partition_a")
tableReference.setTableId("table_$LATEST")
-    Table.Ref(tableReference).latest().ref.getTableId shouldBe "table_20170302"
+    Table(tableReference).latest().ref.getTableId shouldBe "table_20170302"

-    Table
-      .Spec("data-integration-test:partition_a.table_$LATEST")
+    Table("data-integration-test:partition_a.table_$LATEST")
.latest()
.ref
.getTableId shouldBe "table_20170302"
@@ -210,7 +209,7 @@ class BigQueryTypeIT extends AnyFlatSpec with Matchers {
val bqt = BigQueryType[FromTableT]
bqt.isQuery shouldBe false
bqt.isTable shouldBe true
-    bqt.query shouldBe None
+    bqt.queryRaw shouldBe None
bqt.table shouldBe Some("bigquery-public-data:samples.shakespeare")
val fields = bqt.schema.getFields.asScala
fields.size shouldBe 4
11 changes: 11 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
@@ -51,6 +51,7 @@ import scala.reflect.ClassTag
import scala.util.control.NoStackTrace
import scala.util.{Failure, Success, Try}
import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions
+import org.apache.beam.sdk.transforms.errorhandling.{BadRecord, ErrorHandler}

/** Runner specific context. */
trait RunnerContext {
@@ -851,6 +852,16 @@ class ScioContext private[scio] (
this.applyTransform(Create.timestamped(v.asJava).withCoder(coder))
}

+  // =======================================================================
+  // Error handler
+  // =======================================================================
+  def registerBadRecordErrorHandler[T <: POutput](
+    sinkTransform: PTransform[PCollection[BadRecord], T]
+  ): ErrorHandler[BadRecord, T] =
+    pipeline.registerBadRecordErrorHandler(sinkTransform)
+
+  def errorSink(): ErrorSink = ErrorSink(this)

// =======================================================================
// Metrics
// =======================================================================
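The two new hooks mirror Beam's bad-record API: `registerBadRecordErrorHandler` delegates to `Pipeline#registerBadRecordErrorHandler`, and `errorSink()` wraps it (see ErrorSink below). A sketch of direct registration, assuming a `ScioContext` named `sc` in scope; the counting transform is illustrative, not part of this PR:

    import org.apache.beam.sdk.transforms.errorhandling.{BadRecord, ErrorHandler}
    import org.apache.beam.sdk.transforms.{Count, PTransform}
    import org.apache.beam.sdk.values.PCollection

    // Any PTransform from PCollection[BadRecord] to a POutput can serve as the sink
    val countBadRecords = new PTransform[PCollection[BadRecord], PCollection[java.lang.Long]] {
      override def expand(input: PCollection[BadRecord]): PCollection[java.lang.Long] =
        input.apply(Count.globally[BadRecord]())
    }

    val handler: ErrorHandler[BadRecord, PCollection[java.lang.Long]] =
      sc.registerBadRecordErrorHandler(countBadRecords)
    // Beam requires the handler to be closed before the pipeline runs;
    // ErrorSink#sink (below) performs that close.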
@@ -29,6 +29,7 @@ import org.apache.beam.sdk.io.FileIO.ReadableFile
import org.apache.beam.sdk.io.fs.{MatchResult, MetadataCoderV2, ResourceId, ResourceIdCoder}
import org.apache.beam.sdk.io.ReadableFileCoder
import org.apache.beam.sdk.schemas.{Schema => BSchema}
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord
import org.apache.beam.sdk.transforms.windowing.{
BoundedWindow,
GlobalWindow,
@@ -66,6 +67,11 @@ trait BeamTypeCoders extends CoderGrammar {
str => DefaultJsonObjectParser.parseAndClose(new StringReader(str), ScioUtil.classOf[T]),
DefaultJsonObjectParser.getJsonFactory().toString(_)
)

+  // BadRecord and its nested types are Serializable, so Kryo coders suffice
+  implicit val badRecordCoder: Coder[BadRecord] = kryo
+  implicit val badRecordRecordCoder: Coder[BadRecord.Record] = kryo
+  implicit val badRecordFailureCoder: Coder[BadRecord.Failure] = kryo
}

private[coders] object BeamTypeCoders extends BeamTypeCoders {
56 changes: 56 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/values/ErrorSink.scala
@@ -0,0 +1,56 @@
/*
* Copyright 2024 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.spotify.scio.values

import com.spotify.scio.ScioContext
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.transforms.errorhandling.{BadRecord, ErrorHandler}
import org.apache.beam.sdk.values.{PCollection, PCollectionTuple, TupleTag}

/**
 * A sink for error records.
 *
 * Once the [[sink]] is materialized, the [[handler]] must not be used anymore.
 */

Review comment (contributor): bit more explanation on error records could be helpful, maybe:

    Suggested change:
      * A sink for error records.
    + *
    + * An error record is produced by certain PTransforms that catch processing exceptions and
    + * transform the resulting (element, exception) pair into a [[BadRecord]] instance. When an
    + * ErrorSink is configured (via ScioContext#errorSink), these BadRecords can be accessed as
    + * an SCollection by invoking the ErrorSink#sink method. An ErrorSink is useful if you'd
    + * like to set up special handling of exceptions (incrementing Counters, logging the
    + * exceptions in a database, etc).
sealed trait ErrorSink {
def handler: ErrorHandler[BadRecord, _]
Review comment (contributor): could `def handler` be private[scio]? not sure when a user would need to access this

Reply (author): This is the API exposed by Beam. As mentioned in the description, we do not pass the ErrorSink directly:

    sc.bigQueryStorageFormat[MyType](
      table,
      format,
      errorHandler = errorSink.handler
    )

I was thinking of adding to the ScioContext a Beam-Java-like API too:

    def registerBadRecordErrorHandler[T](sinkTransform: PTransform[PCollection[BadRecord], T]): BadRecordErrorHandler[T]

def sink: SCollection[BadRecord]
}

object ErrorSink {

private class SinkSideOutput(tag: TupleTag[BadRecord])
extends PTransform[PCollection[BadRecord], PCollectionTuple] {
override def expand(input: PCollection[BadRecord]): PCollectionTuple =
PCollectionTuple.of(tag, input)
}

private[scio] def apply(context: ScioContext): ErrorSink = {
new ErrorSink {
private val tupleTag: TupleTag[BadRecord] = new TupleTag[BadRecord]()

override val handler: ErrorHandler[BadRecord, PCollectionTuple] =
context.pipeline.registerBadRecordErrorHandler(new SinkSideOutput(tupleTag))

override def sink: SCollection[BadRecord] = {
handler.close()
val output = handler.getOutput
context.wrap(output.get(tupleTag))
}
}
}
}
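Putting the pieces together, a hedged end-to-end sketch of the intended flow; the `errorHandler` parameter on the BigQuery read is the shape proposed in the discussion above, not a settled signature:

    val sc = ScioContext(options)
    val errors = sc.errorSink()

    // Hypothetical wiring of the Beam handler into a read (per the PR discussion)
    val rows = sc.bigQueryTableFormat(
      tableRowTable,
      BigQueryIO.Format.Default(BigQueryType[Record]),
      errorHandler = errors.handler // assumed parameter name
    )

    // Materializing the sink closes the handler; the handler must not be reused afterwards
    val badRecords: SCollection[BadRecord] = errors.sink
    badRecords.count.debug()

    sc.run()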
@@ -145,7 +145,7 @@ object AutoComplete {
if (outputToBigqueryTable) {
tags
.map(kv => Record(kv._1, kv._2.map(p => Tag(p._1, p._2)).toList))
-        .saveAsTypedBigQueryTable(Table.Spec(args("output")))
+        .saveAsTypedBigQueryTable(Table(args("output")))
}
if (outputToDatastore) {
val kind = args.getOrElse("kind", "autocomplete-demo")
@@ -49,7 +49,7 @@ object StreamingWordExtract {
.flatMap(_.split("[^a-zA-Z']+").filter(_.nonEmpty))
.map(_.toUpperCase)
.map(s => TableRow("string_field" -> s))
-      .saveAsBigQueryTable(Table.Spec(args("output")), schema)
+      .saveAsBigQueryTable(Table(args("output")), schema)

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -126,7 +126,7 @@ object TrafficMaxLaneFlow {
ts
)
}
-      .saveAsTypedBigQueryTable(Table.Spec(args("output")))
+      .saveAsTypedBigQueryTable(Table(args("output")))

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -111,7 +111,7 @@ object TrafficRoutes {
.map { case (r, ts) =>
Record(r.route, r.avgSpeed, r.slowdownEvent, ts)
}
-      .saveAsTypedBigQueryTable(Table.Spec(args("output")))
+      .saveAsTypedBigQueryTable(Table(args("output")))

val result = sc.run()
exampleUtils.waitToFinish(result.pipelineResult)
@@ -113,7 +113,7 @@ object GameStats {
// Done using windowing information, convert back to regular `SCollection`
.toSCollection
// Save to the BigQuery table defined by "output" in the arguments passed in + "_team" suffix
-      .saveAsTypedBigQueryTable(Table.Spec(args("output") + "_team"))
+      .saveAsTypedBigQueryTable(Table(args("output") + "_team"))

userEvents
// Window over a variable length of time - sessions end after sessionGap minutes no activity
@@ -141,7 +141,7 @@ object GameStats {
AvgSessionLength(mean, fmt.print(w.start()))
}
// Save to the BigQuery table defined by "output" + "_sessions" suffix
-      .saveAsTypedBigQueryTable(Table.Spec(args("output") + "_sessions"))
+      .saveAsTypedBigQueryTable(Table(args("output") + "_sessions"))

// Execute the pipeline
val result = sc.run()
@@ -91,7 +91,7 @@ object HourlyTeamScore {
TeamScoreSums(team, score, start)
}
// Save to the BigQuery table defined by "output" in the arguments passed in
-      .saveAsTypedBigQueryTable(Table.Spec(args("output")))
+      .saveAsTypedBigQueryTable(Table(args("output")))

// Execute the pipeline
sc.run()
@@ -96,7 +96,7 @@ object LeaderBoard {
// Done with windowing information, convert back to regular `SCollection`
.toSCollection
// Save to the BigQuery table defined by "output" in the arguments passed in + "_team" suffix
-      .saveAsTypedBigQueryTable(Table.Spec(args("output") + "_team"))
+      .saveAsTypedBigQueryTable(Table(args("output") + "_team"))

gameEvents
// Use a global window for unbounded data, which updates calculation every 10 minutes,
@@ -126,7 +126,7 @@ object LeaderBoard {
// Map summed results from tuples into `UserScoreSums` case class, so we can save to BQ
.map(kv => UserScoreSums(kv._1, kv._2, fmt.print(Instant.now())))
// Save to the BigQuery table defined by "output" in the arguments passed in + "_user" suffix
-      .saveAsTypedBigQueryTable(Table.Spec(args("output") + "_user"))
+      .saveAsTypedBigQueryTable(Table(args("output") + "_user"))

// Execute the pipeline
val result = sc.run()
@@ -62,7 +62,7 @@ object UserScore {
// Map summed results from tuples into `UserScoreSums` case class, so we can save to BQ
.map(UserScoreSums.tupled)
// Save to the BigQuery table defined by "output" in the arguments passed in
-      .saveAsTypedBigQueryTable(Table.Spec(args("output")))
+      .saveAsTypedBigQueryTable(Table(args("output")))

// Execute the pipeline
sc.run()
@@ -45,7 +45,7 @@ object BigQueryTornadoes {
)

// Open a BigQuery table as a `SCollection[TableRow]`
-    val table = Table.Spec(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
+    val table = Table(args.getOrElse("input", ExampleData.WEATHER_SAMPLES_TABLE))
val resultTap = sc
.bigQueryTable(table)
// Extract months with tornadoes
@@ -55,7 +55,7 @@ object BigQueryTornadoes {
// Map `(Long, Long)` tuples into result `TableRow`s
.map(kv => TableRow("month" -> kv._1, "tornado_count" -> kv._2))
// Save result as a BigQuery table
-      .saveAsBigQueryTable(Table.Spec(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)
+      .saveAsBigQueryTable(Table(args("output")), schema, WRITE_TRUNCATE, CREATE_IF_NEEDED)

// Access the loaded tables
resultTap