Commit

Support Spark 3.4 (#754)
* first try adding 3.4 support

* first try adding 3.4 support

* first try adding 3.4 support

* first try adding 3.4 support

* moved main logic of ExcelOptions to ExcelOptionsTrait

* cleanup imports

* removed accidental changes
christianknoepfle authored Aug 1, 2023
1 parent 4b55f83 commit ea38542
Showing 11 changed files with 282 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -31,7 +31,7 @@ jobs:
os: [ubuntu-latest]
scala: [2.12.18, 2.13.11]
java: [temurin@8]
- spark: [2.4.1, 2.4.7, 2.4.8, 3.0.1, 3.0.3, 3.1.1, 3.1.2, 3.1.3, 3.2.4, 3.3.1]
+ spark: [2.4.1, 2.4.7, 2.4.8, 3.0.1, 3.0.3, 3.1.1, 3.1.2, 3.1.3, 3.2.4, 3.3.2, 3.4.1]
exclude:
- spark: 2.4.1
scala: 2.13.11
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ This library requires Spark 2.0+.

List of Spark versions that are automatically tested:
```
spark: ["2.4.1", "2.4.7", "2.4.8", "3.0.1", "3.0.3", "3.1.1", "3.1.2", "3.2.1"]
spark: ["2.4.1", "2.4.7", "2.4.8", "3.0.1", "3.0.3", "3.1.1", "3.1.2", "3.2.4", "3.3.2", "3.4.1"]
```
For more detail, please refer to project CI: [ci.yml](https://github.com/crealytics/spark-excel/blob/main/.github/workflows/ci.yml#L10)

47 changes: 36 additions & 11 deletions build.sc
@@ -4,33 +4,40 @@ import coursier.maven.MavenRepository
import mill._, scalalib._, publish._
import mill.modules.Assembly._

- class SparkModule(_scalaVersion: String, sparkVersion: String) extends SbtModule with CiReleaseModule { outer =>
+ class SparkModule(_scalaVersion: String, sparkVersion: String) extends SbtModule with CiReleaseModule {
+ outer =>

override def scalaVersion = _scalaVersion

override def millSourcePath = super.millSourcePath / os.up / os.up / os.up

// Custom source layout for Spark Data Source API 2
- val sparkVersionSpecificSources = if (sparkVersion >= "3.3.0") {
- Seq("scala", "3.3/scala", "3.0_and_up/scala", "3.1_and_up/scala", "3.2_and_up/scala")
+ val sparkVersionSpecificSources = if (sparkVersion >= "3.4.0") {
+ Seq("scala", "3.0_and_up/scala", "3.1_and_up/scala", "3.2_and_up/scala", "3.3_and_up/scala", "3.4_and_up/scala")
+ } else if (sparkVersion >= "3.3.0") {
+ Seq("scala", "3.0_3.1_3.2_3.3/scala", "3.0_and_up/scala", "3.1_and_up/scala", "3.2_and_up/scala", "3.3_and_up/scala")
} else if (sparkVersion >= "3.2.0") {
Seq("scala", "3.0_3.1_3.2/scala", "3.0_and_up/scala", "3.1_and_up/scala", "3.2_and_up/scala")
Seq("scala", "3.0_3.1_3.2/scala", "3.0_3.1_3.2_3.3/scala", "3.0_and_up/scala", "3.1_and_up/scala", "3.2_and_up/scala")
} else if (sparkVersion >= "3.1.0") {
Seq("scala", "3.1/scala", "3.0_3.1/scala", "3.0_3.1_3.2/scala", "3.0_and_up/scala", "3.1_and_up/scala")
Seq("scala", "3.1/scala", "3.0_3.1/scala", "3.0_3.1_3.2_3.3/scala", "3.0_3.1_3.2/scala", "3.0_and_up/scala", "3.1_and_up/scala")
} else if (sparkVersion >= "3.0.0") {
Seq("scala", "3.0/scala", "3.0_3.1/scala", "3.0_3.1_3.2/scala", "3.0_and_up/scala")
Seq("scala", "3.0/scala", "3.0_3.1/scala", "3.0_3.1_3.2_3.3/scala", "3.0_3.1_3.2/scala", "3.0_and_up/scala")
} else if (sparkVersion >= "2.4.0") {
Seq("scala", "2.4/scala")
} else {
throw new UnsupportedOperationException(s"sparkVersion ${sparkVersion} is not supported")
}

override def sources = T.sources {
super.sources() ++ sparkVersionSpecificSources.map(s => PathRef(millSourcePath / "src" / "main" / os.RelPath(s)))
}

override def docSources = T.sources(Seq[PathRef]())

override def artifactName = "spark-excel"

override def publishVersion = s"${sparkVersion}_${super.publishVersion()}"

def pomSettings = PomSettings(
description = "A Spark plugin for reading and writing Excel files",
organization = "com.crealytics",
@@ -45,20 +52,23 @@ class SparkModule(_scalaVersion: String, sparkVersion: String) extends SbtModule
Rule.Relocate("org.apache.commons.io.**", "shadeio.commons.io.@1"),
Rule.Relocate("org.apache.commons.compress.**", "shadeio.commons.compress.@1")
)

override def extraPublish = Seq(PublishInfo(assembly(), classifier = None, ivyConfig = "compile"))

val sparkDeps = Agg(
ivy"org.apache.spark::spark-core:$sparkVersion",
ivy"org.apache.spark::spark-sql:$sparkVersion",
ivy"org.apache.spark::spark-hive:$sparkVersion"
)

override def compileIvyDeps = if (sparkVersion < "3.3.0") {
sparkDeps ++ Agg(ivy"org.slf4j:slf4j-api:1.7.36".excludeOrg("stax"))
} else {
sparkDeps
}

val poiVersion = "5.2.3"

override def ivyDeps = {
val base = Agg(
ivy"org.apache.poi:poi:$poiVersion",
@@ -84,15 +94,25 @@ class SparkModule(_scalaVersion: String, sparkVersion: String) extends SbtModule
base
}
}

object test extends Tests with SbtModule with TestModule.ScalaTest {

override def millSourcePath = super.millSourcePath

override def sources = T.sources {
Seq(PathRef(millSourcePath / "src" / "test" / "scala"))
}
- override def resources = T.sources { Seq(PathRef(millSourcePath / "src" / "test" / "resources")) }

+ override def resources = T.sources {
+ Seq(PathRef(millSourcePath / "src" / "test" / "resources"))
+ }

def scalaVersion = outer.scalaVersion()
- def repositoriesTask = T.task { super.repositoriesTask() ++ Seq(MavenRepository("https://jitpack.io")) }

+ def repositoriesTask = T.task {
+ super.repositoriesTask() ++ Seq(MavenRepository("https://jitpack.io"))
+ }

def ivyDeps = sparkDeps ++ Agg(
ivy"org.typelevel::cats-core:2.9.0",
ivy"org.scalatest::scalatest:3.2.16",
@@ -111,11 +131,16 @@ val spark24 = List("2.4.1", "2.4.7", "2.4.8")
val spark30 = List("3.0.1", "3.0.3")
val spark31 = List("3.1.1", "3.1.2", "3.1.3")
val spark32 = List("3.2.4")
val spark33 = List("3.3.1")
val spark33 = List("3.3.2")
val spark34 = List("3.4.1")

+ val crossMatrix = {
+
- val crossMatrix =
- (spark24 ++ spark30 ++ spark31 ++ spark32 ++ spark33).map(spark => (scala212, spark)) ++ (spark32 ++ spark33).map(
+ (spark24 ++ spark30 ++ spark31 ++ spark32 ++ spark33 ++ spark34).map(spark => (scala212, spark)) ++ (spark32 ++ spark33 ++ spark34).map(
spark => (scala213, spark)
)
+
+ // (spark34).map(spark => (scala212, spark))
+ }

object `spark-excel` extends Cross[SparkModule](crossMatrix: _*) {}
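
The cross-build matrix above pairs each Scala line with the Spark releases it supports. The following standalone sketch (not part of build.sc; the scala212/scala213 values are assumptions taken from the CI matrix above) enumerates the combinations the expression produces:

```scala
// Standalone illustration of the crossMatrix expression in build.sc above.
// scala212 / scala213 are assumptions based on the CI matrix (2.12.18, 2.13.11).
object CrossMatrixDemo extends App {
  val scala212 = "2.12.18"
  val scala213 = "2.13.11"

  val spark24 = List("2.4.1", "2.4.7", "2.4.8")
  val spark30 = List("3.0.1", "3.0.3")
  val spark31 = List("3.1.1", "3.1.2", "3.1.3")
  val spark32 = List("3.2.4")
  val spark33 = List("3.3.2")
  val spark34 = List("3.4.1")

  // Scala 2.12 is built against every supported Spark line; Scala 2.13 only
  // against Spark 3.2 and newer, mirroring the artifacts Spark itself publishes.
  val crossMatrix =
    (spark24 ++ spark30 ++ spark31 ++ spark32 ++ spark33 ++ spark34).map(spark => (scala212, spark)) ++
      (spark32 ++ spark33 ++ spark34).map(spark => (scala213, spark))

  crossMatrix.foreach { case (scalaV, sparkV) => println(s"Scala $scalaV / Spark $sparkV") }
}
```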
@@ -0,0 +1,38 @@
/*
* Copyright 2022 Martin Mauch (@nightscape)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.crealytics.spark.excel.v2

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf

class ExcelOptions(
@transient
val parameters: CaseInsensitiveMap[String],
val defaultTimeZoneId: String,
val defaultColumnNameOfCorruptRecord: String
) extends ExcelOptionsTrait with Serializable {
// all parameter handling is implemented in ExcelOptionsTrait

def this(parameters: Map[String, String], defaultTimeZoneId: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, SQLConf.get.columnNameOfCorruptRecord)
}

def this(parameters: Map[String, String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, defaultColumnNameOfCorruptRecord)
}

}
@@ -0,0 +1,38 @@
/*
* Copyright 2022 Martin Mauch (@nightscape)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.crealytics.spark.excel.v2

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf

class ExcelOptions(
@transient
val parameters: CaseInsensitiveMap[String],
val defaultTimeZoneId: String,
val defaultColumnNameOfCorruptRecord: String
) extends ExcelOptionsTrait with Serializable {
// all parameter handling is implemented in ExcelOptionsTrait

def this(parameters: Map[String, String], defaultTimeZoneId: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, SQLConf.get.columnNameOfCorruptRecord)
}

def this(parameters: Map[String, String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, defaultColumnNameOfCorruptRecord)
}

}
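
Both new ExcelOptions variants above delegate their parameter handling to ExcelOptionsTrait, which is among the changed files not shown in the visible hunks. A minimal sketch of that shared-trait pattern follows, assuming the trait exposes option parsing over the `parameters` map supplied by each concrete class; only `columnNameOfCorruptRecord` is taken from this diff (the factory below uses it), everything else about the member set is an assumption:

```scala
// Hedged sketch of the shared trait; the real ExcelOptionsTrait in this commit
// carries all of the option parsing that previously lived in ExcelOptions itself.
package com.crealytics.spark.excel.v2

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

trait ExcelOptionsTrait {
  // Supplied by each Spark-version-specific ExcelOptions class.
  def parameters: CaseInsensitiveMap[String]
  def defaultTimeZoneId: String
  def defaultColumnNameOfCorruptRecord: String

  // Shared parsing lives here so the Spark 3.4 variant (which must extend
  // FileSourceOptions) and the pre-3.4 variants cannot drift apart.
  // The option key below is illustrative; the fallback is the default passed in.
  val columnNameOfCorruptRecord: String =
    parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
}
```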
@@ -0,0 +1,41 @@
/*
* Copyright 2022 Martin Mauch (@nightscape)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.crealytics.spark.excel.v2

import org.apache.spark.sql.catalyst.FileSourceOptions
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.internal.SQLConf


class ExcelOptions(
@transient
val parameters: CaseInsensitiveMap[String],
val defaultTimeZoneId: String,
val defaultColumnNameOfCorruptRecord: String
) extends FileSourceOptions(parameters)
with ExcelOptionsTrait {


def this(parameters: Map[String, String], defaultTimeZoneId: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, SQLConf.get.columnNameOfCorruptRecord)
}

def this(parameters: Map[String, String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) = {
this(CaseInsensitiveMap(parameters), defaultTimeZoneId, defaultColumnNameOfCorruptRecord)
}

}
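
This variant additionally extends FileSourceOptions (presumably the copy for the Spark 3.4 source tree) while keeping the same constructors as the other two, so call sites can build the options identically on every Spark version. A minimal usage sketch follows; the option keys and values are illustrative assumptions, not settings taken from this diff:

```scala
// Hypothetical call site; option keys/values below are assumptions for illustration.
import com.crealytics.spark.excel.v2.ExcelOptions

object ExcelOptionsUsage {
  val excelOptions = new ExcelOptions(
    Map("header" -> "true", "dataAddress" -> "'Sheet1'!A1"), // assumed option names
    defaultTimeZoneId = "UTC"
  )
  // The two-argument constructor falls back to SQLConf.get.columnNameOfCorruptRecord,
  // as the listing above shows.
}
```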
@@ -0,0 +1,100 @@
/*
* Copyright 2022 Martin Mauch (@nightscape)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.crealytics.spark.excel.v2

import org.apache.hadoop.conf.Configuration
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
import scala.util.control.NonFatal

/** A factory used to create Excel readers.
*
* @param sqlConf
* SQL configuration.
* @param broadcastedConf
* Broadcasted serializable Hadoop Configuration.
* @param dataSchema
* Schema of Excel files.
* @param readDataSchema
* Required data schema in the batch scan.
* @param partitionSchema
* Schema of partitions.
* @param options
* Options for parsing Excel files.
*/
case class ExcelPartitionReaderFactory(
sqlConf: SQLConf,
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
options: ExcelOptions,
filters: Seq[Filter]
) extends FilePartitionReaderFactory {

override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value
val actualDataSchema =
StructType(dataSchema.filterNot(_.name == options.columnNameOfCorruptRecord))
val actualReadDataSchema =
StructType(readDataSchema.filterNot(_.name == options.columnNameOfCorruptRecord))
val parser = new ExcelParser(actualDataSchema, actualReadDataSchema, options, filters)
val headerChecker =
new ExcelHeaderChecker(actualReadDataSchema, options, source = s"Excel file: ${file.filePath}")
val iter = readFile(conf, file, parser, headerChecker, readDataSchema)
val partitionReader = new SparkExcelPartitionReaderFromIterator(iter)
new PartitionReaderWithPartitionValues(partitionReader, readDataSchema, partitionSchema, file.partitionValues)
}

private def readFile(
conf: Configuration,
file: PartitionedFile,
parser: ExcelParser,
headerChecker: ExcelHeaderChecker,
requiredSchema: StructType
): SheetData[InternalRow] = {
val excelHelper = ExcelHelper(options)
val sheetData = excelHelper.getSheetData(conf, file.filePath.toUri)
try {
SheetData(
ExcelParser.parseIterator(sheetData.rowIterator, parser, headerChecker, requiredSchema),
sheetData.resourcesToClose
)
} catch {
case NonFatal(t) => {
sheetData.close()
throw t
}
}
}

}

private class SparkExcelPartitionReaderFromIterator(sheetData: SheetData[InternalRow])
extends PartitionReaderFromIterator[InternalRow](sheetData.rowIterator) {
override def close(): Unit = {
super.close()
sheetData.close()
}
}