Skip to content

Commit

Permalink
fix: correctly deserialize consecutive XML flat maps (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis authored Jun 30, 2023
1 parent afb36c8 commit 422e04e
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 5 deletions.
8 changes: 8 additions & 0 deletions .changes/a2c5037e-db65-45f5-8033-acbd3bf241ee.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "a2c5037e-db65-45f5-8033-acbd3bf241ee",
"type": "bugfix",
"description": "Properly deserialize XML flat maps",
"issues": [
"https://github.com/awslabs/aws-sdk-kotlin/issues/962"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@ public class XmlDeserializer(
return XmlListDeserializer(reader.subTreeReader(depth), descriptor)
}

override fun deserializeMap(descriptor: SdkFieldDescriptor): Deserializer.EntryIterator =
XmlMapDeserializer(reader.subTreeReader(XmlStreamReader.SubtreeStartDepth.CURRENT), descriptor)
override fun deserializeMap(descriptor: SdkFieldDescriptor): Deserializer.EntryIterator {
val depth = when (descriptor.hasTrait<Flattened>()) {
true -> XmlStreamReader.SubtreeStartDepth.CURRENT
else -> XmlStreamReader.SubtreeStartDepth.CHILD
}

return XmlMapDeserializer(reader.subTreeReader(depth), descriptor)
}
}

/**
Expand All @@ -98,10 +104,15 @@ internal class XmlMapDeserializer(
private val mapTrait = descriptor.findTrait<XmlMapName>() ?: XmlMapName.Default

override fun hasNextEntry(): Boolean {
// Seek to either the entry or key token depending on the flatness of the map
val compareTo = when (descriptor.hasTrait<Flattened>()) {
true -> descriptor.findTrait<XmlSerialName>()?.name ?: mapTrait.key // Prefer seeking to XmlSerialName if the trait exists
false -> mapTrait.entry
}

// Seek to either the XML serial name, entry, or key token depending on the flatness of the map and if the name trait is present
val nextEntryToken = when (descriptor.hasTrait<Flattened>()) {
true -> reader.seek<XmlToken.BeginElement> { it.name.local == mapTrait.key }
false -> reader.seek<XmlToken.BeginElement> { it.name.local == mapTrait.entry }
true -> reader.peekSeek<XmlToken.BeginElement> { it.name.local == compareTo }
false -> reader.seek<XmlToken.BeginElement> { it.name.local == compareTo }
}

return nextEntryToken != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,33 @@ public inline fun <reified T : XmlToken> XmlStreamReader.seek(selectionPredicate
return token as T?
}

/**
* Peek and seek forward until a token of type [T] is found.
* If it matches the [selectionPredicate], consume the token and return it. Otherwise, return `null` without consuming the token.
*
* @param selectionPredicate predicate that evaluates nodes of the required type to match
*/
public inline fun <reified T : XmlToken> XmlStreamReader.peekSeek(selectionPredicate: (T) -> Boolean = { true }): T? {
var token: XmlToken? = lastToken

if (token != null && token is T) {
return if (selectionPredicate.invoke(token)) token else null
}

do {
if (token is T) {
return if (selectionPredicate.invoke(token)) {
nextToken() as T
} else {
null
}
} else { nextToken() }
token = peek()
} while (token != null)

return null
}

/**
* Creates an [XmlStreamReader] instance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,221 @@ class XmlDeserializerMapTest {

println(resp)
}

// https://github.com/awslabs/aws-sdk-kotlin/issues/962
@Test
fun itHandlesConsecutiveFlatMaps() {
val payload = """
<object>
<firstMap>
<key>key1</key>
<value>1</value>
</firstMap>
<firstMap>
<key>key2</key>
<value>2</value>
</firstMap>
<firstMap>
<key>key3</key>
<value>3</value>
</firstMap>
<secondMap>
<key>key4</key>
<value>4</value>
</secondMap>
<secondMap>
<key>key5</key>
<value>5</value>
</secondMap>
</object>
""".encodeToByteArray()
val firstMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("firstMap"), XmlMapName(null, "key", "value"), Flattened)
val secondMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("secondMap"), XmlMapName(null, "key", "value"), Flattened)

val objDescriptor = SdkObjectDescriptor.build {
trait(XmlSerialName("object"))
field(firstMapDescriptor)
field(secondMapDescriptor)
}
var firstMap = mutableMapOf<String, Int>()
var secondMap = mutableMapOf<String, Int>()
val deserializer = XmlDeserializer(payload)
deserializer.deserializeStruct(objDescriptor) {
loop@while (true) {
when (findNextFieldIndex()) {
firstMapDescriptor.index ->
firstMap =
deserializer.deserializeMap(firstMapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
secondMapDescriptor.index ->
secondMap =
deserializer.deserializeMap(secondMapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
null -> break@loop
else -> skipValue()
}
}
}

val expectedFirstMap = mapOf("key1" to 1, "key2" to 2, "key3" to 3)
firstMap.shouldContainExactly(expectedFirstMap)
val expectedSecondMap = mapOf("key4" to 4, "key5" to 5)
secondMap.shouldContainExactly(expectedSecondMap)
}

@Test
fun itHandlesMapsFollowedByFlatMaps() {
val payload = """
<object>
<map>
<entry>
<key>key1</key>
<value>1</value>
</entry>
<entry>
<key>key2</key>
<value>2</value>
</entry>
</map>
<flatMap>
<key>key3</key>
<value>3</value>
</flatMap>
<flatMap>
<key>key4</key>
<value>4</value>
</flatMap>
</object>
""".encodeToByteArray()
val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map"))
val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened)
val objDescriptor = SdkObjectDescriptor.build {
trait(XmlSerialName("object"))
field(mapDescriptor)
field(flatMapDescriptor)
}

var map = mutableMapOf<String, Int>()
var flatMap = mutableMapOf<String, Int>()

val deserializer = XmlDeserializer(payload)
deserializer.deserializeStruct(objDescriptor) {
loop@while (true) {
when (findNextFieldIndex()) {
mapDescriptor.index ->
map =
deserializer.deserializeMap(mapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
flatMapDescriptor.index ->
flatMap =
deserializer.deserializeMap(flatMapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
null -> break@loop
else -> skipValue()
}
}
}
map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2))
flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4))
}

@Test
fun itHandlesFlatMapsFollowedByMaps() {
val payload = """
<object>
<flatMap>
<key>key3</key>
<value>3</value>
</flatMap>
<flatMap>
<key>key4</key>
<value>4</value>
</flatMap>
<map>
<entry>
<key>key1</key>
<value>1</value>
</entry>
<entry>
<key>key2</key>
<value>2</value>
</entry>
</map>
</object>
""".encodeToByteArray()
val mapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("map"))
val flatMapDescriptor = SdkFieldDescriptor(SerialKind.Map, XmlSerialName("flatMap"), Flattened)
val objDescriptor = SdkObjectDescriptor.build {
trait(XmlSerialName("object"))
field(mapDescriptor)
field(flatMapDescriptor)
}

var map = mutableMapOf<String, Int>()
var flatMap = mutableMapOf<String, Int>()

val deserializer = XmlDeserializer(payload)
deserializer.deserializeStruct(objDescriptor) {
loop@while (true) {
when (findNextFieldIndex()) {
mapDescriptor.index ->
map =
deserializer.deserializeMap(mapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
flatMapDescriptor.index ->
flatMap =
deserializer.deserializeMap(flatMapDescriptor) {
val map0 = mutableMapOf<String, Int>()
while (hasNextEntry()) {
val k0 = key()
val v0 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
map0[k0] = v0
}
map0
}
null -> break@loop
else -> skipValue()
}
}
}
map.shouldContainExactly(mapOf("key1" to 1, "key2" to 2))
flatMap.shouldContainExactly(mapOf("key3" to 3, "key4" to 4))
}
}

internal class XmlMapsOperationDeserializer() {
Expand Down

0 comments on commit 422e04e

Please sign in to comment.