Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add map_from_entries Spark function #11934

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions velox/docs/functions/spark/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ Map Functions

SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); -- {1.0 -> 2, 3.0 -> 4}

.. spark:function:: map_from_entries(array(struct(K,V))) -> map(K,V)

Returns a map created from the given array of entries. Throws an exception if the entries contain duplicate keys
or if any entry has a null key. Returns NULL if any entry is NULL. ::

SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'null'))); -- {1 -> 'a', 2 -> 'null'}
SELECT map_from_entries(array(struct(1, 'a'), null)); -- NULL
SELECT map_from_entries(array(struct(null, 'a'))); -- "map key cannot be null"
SELECT map_from_entries(array(struct(1, 'a'), struct(1, 'b'))); -- "Duplicate map keys (1) are not allowed"

.. spark:function:: map_keys(x(K,V)) -> array(K)

Returns all the keys in the map ``x``.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ velox_add_library(
CheckNestedNulls.cpp
KllSketch.cpp
MapConcat.cpp
MapFromEntries.cpp
Re2Functions.cpp
Repeat.cpp
Slice.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ namespace {
// User-facing error messages raised while validating map entries. Declared
// constexpr so the pointers themselves cannot be reassigned at runtime.
constexpr const char* kNullKeyErrorMessage = "map key cannot be null";
constexpr const char* kErrorMessageEntryNotNull = "map entry cannot be null";

// See documentation at https://prestodb.io/docs/current/functions/map.html
class MapFromEntriesFunction : public exec::VectorFunction {
public:
// @param throwOnNull If true, throws exception when input array is null or
// contains null entry. Otherwise, returns null.
MapFromEntriesFunction(bool throwOnNull) : throwOnNull_(throwOnNull) {}
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
Expand Down Expand Up @@ -92,14 +94,16 @@ class MapFromEntriesFunction : public exec::VectorFunction {
auto& inputValueVector = inputArray->elements();
exec::LocalDecodedVector decodedRowVector(context);
decodedRowVector.get()->decode(*inputValueVector);
// If the input array(unknown) then all rows should have errors.
if (inputValueVector->typeKind() == TypeKind::UNKNOWN) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
// For presto, if the input array(unknown) then all rows should have
// errors.
if (throwOnNull_) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
context.setErrors(rows, std::current_exception());
}
}

auto sizes = allocateSizes(rows.end(), context.pool());
auto offsets = allocateSizes(rows.end(), context.pool());

Expand Down Expand Up @@ -127,8 +131,9 @@ class MapFromEntriesFunction : public exec::VectorFunction {
});

auto resetSize = [&](vector_size_t row) { mutableSizes[row] = 0; };
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();

// Validate all map entries and map keys are not null.
if (decodedRowVector->mayHaveNulls() || keyVector->mayHaveNulls() ||
keyVector->mayHaveNullsRecursive()) {
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
Expand All @@ -139,11 +144,13 @@ class MapFromEntriesFunction : public exec::VectorFunction {
// Check nulls in the top level row vector.
const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i);
if (isMapEntryNull) {
// Set the sizes to 0 so that the final map vector generated is
// valid in case we are inside a try. The map vector needs to be
// valid because its consumed by checkDuplicateKeys before try
// sets invalid rows to null.
// The map vector needs to be valid because its consumed by
// checkDuplicateKeys before try sets invalid rows to null.
resetSize(row);
if (!throwOnNull_) {
bits::setNull(mutableNulls, row);
break;
}
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
}

Expand Down Expand Up @@ -189,8 +196,6 @@ class MapFromEntriesFunction : public exec::VectorFunction {
} else {
// Dictionary.
auto indices = allocateIndices(decodedRowVector->size(), context.pool());
auto nulls = allocateNulls(decodedRowVector->size(), context.pool());
auto* mutableNulls = nulls->asMutable<uint64_t>();
memcpy(
indices->asMutable<vector_size_t>(),
decodedRowVector->indices(),
Expand All @@ -208,12 +213,20 @@ class MapFromEntriesFunction : public exec::VectorFunction {
nulls, indices, decodedRowVector->size(), rowVector->childAt(1));
}

// To avoid creating new buffers, we try to reuse the input's buffers
// as many as possible.
// For Presto, need construct map vector based on input nulls for possible
// outer expression like try(). For Spark, use the updated nulls unless it's
// empty.
if (throwOnNull_) {
nulls = inputArray->nulls();
} else {
if (decodedRowVector->size() == 0) {
nulls = inputArray->nulls();
}
}
auto mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
inputArray->nulls(),
nulls,
rows.end(),
inputArray->offsets(),
sizes,
Expand All @@ -223,11 +236,15 @@ class MapFromEntriesFunction : public exec::VectorFunction {
checkDuplicateKeys(mapVector, *remianingRows, context);
return mapVector;
}

const bool throwOnNull_;
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_map_from_entries,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>());
void registerMapFromEntriesFunction(const std::string& name, bool throwOnNull) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction::signatures(),
std::make_unique<MapFromEntriesFunction>(throwOnNull));
}
} // namespace facebook::velox::functions
27 changes: 27 additions & 0 deletions velox/functions/lib/MapFromEntries.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions {

/// Registers the map_from_entries vector function under the given name.
/// @param name Name the function is registered under.
/// @param throwOnNull If true, a null map entry raises an error. Otherwise,
/// the result for a row containing a null entry is null.
void registerMapFromEntriesFunction(const std::string& name, bool throwOnNull);

} // namespace facebook::velox::functions
1 change: 0 additions & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ velox_add_library(
JsonFunctions.cpp
Map.cpp
MapEntries.cpp
MapFromEntries.cpp
MapKeysAndValues.cpp
MapZipWith.cpp
Not.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/lib/MapFromEntries.h"
#include "velox/functions/prestosql/Map.h"
#include "velox/functions/prestosql/MapFunctions.h"
#include "velox/functions/prestosql/MapNormalize.h"
Expand Down Expand Up @@ -87,8 +88,8 @@ void registerMapFunctions(const std::string& prefix) {
udf_transform_values, prefix + "transform_values");
registerMapFunction(prefix + "map", false /*allowDuplicateKeys*/);
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_from_entries, prefix + "map_from_entries");
registerMapFromEntriesFunction(prefix + "map_from_entries", true);

VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with");
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/sparksql/registration/RegisterMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/lib/MapFromEntries.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/sparksql/Size.h"

Expand All @@ -24,6 +25,7 @@ extern void registerElementAtFunction(
void registerSparkMapFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_allow_duplicates, prefix + "map_from_arrays");
registerMapFromEntriesFunction(prefix + "map_from_entries", false);
// Spark and Presto map_filter function has the same definition:
// function expression corresponds to body, arguments to signature
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_filter, prefix + "map_filter");
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ add_executable(
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapTest.cpp
MapFromEntriesTest.cpp
MaskTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
Expand Down
162 changes: 162 additions & 0 deletions velox/functions/sparksql/tests/MapFromEntriesTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/ArrayConstructor.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {
// Wraps a list of (key, optional value) pairs into an engaged std::optional,
// so a non-null map can be written next to std::nullopt (a null map) when
// building expected results.
std::optional<std::vector<std::pair<int32_t, std::optional<int32_t>>>> O(
    const std::vector<std::pair<int32_t, std::optional<int32_t>>>& vector) {
  std::optional<std::vector<std::pair<int32_t, std::optional<int32_t>>>>
      wrapped;
  wrapped.emplace(vector);
  return wrapped;
}

// Test fixture for the Spark map_from_entries function.
class MapFromEntriesTest : public SparkFunctionBaseTest {
 protected:
  // Creates a MAP vector of size 1 from the specified 'keys' and 'values'
  // vectors. The size of the single row is set to keys->size().
  VectorPtr makeSingleRowMapVector(
      const VectorPtr& keys,
      const VectorPtr& values) {
    BufferPtr offsets = allocateOffsets(1, pool());
    BufferPtr sizes = allocateSizes(1, pool());
    sizes->asMutable<vector_size_t>()[0] = keys->size();

    return std::make_shared<MapVector>(
        pool(),
        MAP(keys->type(), values->type()),
        nullptr,
        1,
        offsets,
        sizes,
        keys,
        values);
  }

  // Evaluates map_from_entries(c0) with 'input' as column c0 and asserts that
  // the result equals 'expected'. The expression is a constant, so it is
  // passed as a literal rather than built with fmt::format.
  void verifyMapFromEntries(const VectorPtr& input, const VectorPtr& expected) {
    auto result = evaluate("map_from_entries(c0)", makeRowVector({input}));
    assertEqualVectors(expected, result);
  }
};
} // namespace

// A row whose array contains a null entry yields a null map; rows without null
// entries are converted normally.
TEST_F(MapFromEntriesTest, nullMapEntries) {
  using Entries = std::vector<std::optional<std::tuple<int32_t, int32_t>>>;
  auto entryType = ROW({INTEGER(), INTEGER()});

  // Case 1: one row with a single null entry, one row with a valid entry.
  {
    std::vector<Entries> rows = {
        {std::nullopt},
        {{{1, 11}}},
    };
    auto input = makeArrayOfRowVector(rows, entryType);
    auto expected =
        makeNullableMapVector<int32_t, int32_t>({std::nullopt, O({{1, 11}})});
    verifyMapFromEntries(input, expected);
  }

  // Case 2: every entry is null, and the key/value children of the row vector
  // are shrunk to size 0 to model array(row(a, b)) with no child data.
  {
    std::vector<Entries> rows = {
        {std::nullopt, std::nullopt, std::nullopt},
        {std::nullopt},
    };
    auto input = makeArrayOfRowVector(rows, entryType);
    auto* entries = input->as<ArrayVector>()->elements()->as<RowVector>();
    entries->childAt(0)->resize(0);
    entries->childAt(1)->resize(0);

    auto expected =
        makeNullableMapVector<int32_t, int32_t>({std::nullopt, std::nullopt});
    verifyMapFromEntries(input, expected);
  }
}

// An entry with a null key must fail with "map key cannot be null".
TEST_F(MapFromEntriesTest, nullKeys) {
  // Row 0 holds an entry whose key is null; row 1 holds the entry (1, 11).
  std::vector<std::vector<variant>> entries = {
      {variant::row({variant::null(TypeKind::INTEGER), 0})},
      {variant::row({1, 11})}};
  auto entryType = ROW({INTEGER(), INTEGER()});
  auto arrayOfEntries = makeArrayOfRowVector(entryType, entries);

  VELOX_ASSERT_THROW(
      evaluate("map_from_entries(c0)", makeRowVector({arrayOfEntries})),
      "map key cannot be null");
}

// An array whose elements are a constant-wrapped null row must produce a null
// map for every input row.
TEST_F(MapFromEntriesTest, arrayOfConstantRowOfNulls) {
  // Build a ROW vector with a single null row and empty key/value children.
  RowVectorPtr rowVector =
      makeRowVector({makeFlatVector<int32_t>(0), makeFlatVector<int32_t>(0)});
  rowVector->resize(1);
  rowVector->setNull(0, true);
  rowVector->childAt(0)->resize(0);
  rowVector->childAt(1)->resize(0);
  EXPECT_EQ(rowVector->childAt(0)->size(), 0);
  EXPECT_EQ(rowVector->childAt(1)->size(), 0);

  // Wrap row 0 (the null row) in a constant vector of size 4.
  VectorPtr rowVectorConstant = BaseVector::wrapInConstant(4, 0, rowVector);

  // Two arrays of two entries each, at offsets 0 and 2.
  auto offsets = makeIndices({0, 2});
  auto sizes = makeIndices({2, 2});

  auto arrayVector = std::make_shared<ArrayVector>(
      pool(),
      ARRAY(ROW({INTEGER(), INTEGER()})),
      nullptr,
      2,
      offsets,
      sizes,
      rowVectorConstant);
  VectorPtr result =
      evaluate("map_from_entries(c0)", makeRowVector({arrayVector}));

  // Without this check the loop below would pass vacuously on an empty result.
  ASSERT_EQ(result->size(), arrayVector->size());
  for (vector_size_t i = 0; i < result->size(); ++i) {
    EXPECT_TRUE(result->isNullAt(i));
  }
}

// Inputs built from untyped nulls: a row of (null, null) fails on the null
// key, while a null entry or a null array evaluates to a result of type
// MAP(UNKNOWN, UNKNOWN).
TEST_F(MapFromEntriesTest, unknownInputs) {
  facebook::velox::functions::registerArrayConstructor("array_constructor");

  const auto unknownMapType = MAP(UNKNOWN(), UNKNOWN());
  // Each evaluation runs against a fresh two-row input.
  const auto makeInput = [&]() {
    return makeRowVector({makeFlatVector<int32_t>(2)});
  };
  const auto expectUnknownMap = [&](const std::string& query) {
    auto result = evaluate(query, makeInput());
    ASSERT_TRUE(result->type()->equivalent(*unknownMapType));
  };

  VELOX_ASSERT_THROW(
      evaluate(
          "map_from_entries(array_constructor(row_constructor(null, null)))",
          makeInput()),
      "map key cannot be null");
  expectUnknownMap("map_from_entries(array_constructor(null))");
  expectUnknownMap("map_from_entries(null)");
}

// A null key must be reported even when the array also holds a null entry.
TEST_F(MapFromEntriesTest, nullRowEntriesWithSmallerChildren) {
  // The row vector has 3 rows, while its children have 2 values each since
  // row 2 is null.
  auto keys = makeNullableFlatVector<int32_t>({std::nullopt, 2});
  auto values = makeFlatVector<int32_t>({1, 2});
  auto entries = makeRowVector({keys, values});
  entries->appendNulls(1);
  entries->setNull(2, true);

  // Single array: [(null, 1), (2, 2), null].
  auto arrayOfEntries = makeArrayVector({0}, entries);
  VELOX_ASSERT_THROW(
      evaluate("map_from_entries(c0)", makeRowVector({arrayOfEntries})),
      "map key cannot be null");
}
} // namespace facebook::velox::functions::sparksql::test
Loading