Skip to content

Commit

Permalink
Add tests for index_matrix and index_vector. (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein authored Aug 29, 2024
1 parent 0d1958c commit 241d681
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


def safe_len(X: Matrix) -> int:
"""Determine the length of a Matrix."""
if scipy.sparse.issparse(X):
return X.shape[0]
return len(X)
Expand Down
59 changes: 59 additions & 0 deletions tests/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from glum import GeneralizedLinearRegressor, GeneralizedLinearRegressorCV
from lightgbm import LGBMClassifier, LGBMRegressor
from scipy.sparse import csr_matrix
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.linear_model import LinearRegression
from xgboost import XGBClassifier, XGBRegressor
Expand All @@ -20,6 +21,8 @@
convert_treatment,
function_has_argument,
get_linear_dimension,
index_matrix,
index_vector,
supports_categoricals,
validate_all_vectors_same_index,
validate_model_and_predict_method,
Expand Down Expand Up @@ -345,3 +348,59 @@ def test_validate_valid_treatment_variant_not_control(
else:
with pytest.raises(ValueError, match="variant"):
validate_valid_treatment_variant_not_control(treatment_variant, n_variants)


@pytest.mark.parametrize("matrix_backend", [np.ndarray, pd.DataFrame, csr_matrix])
@pytest.mark.parametrize("rows_backend", [np.array, pd.Series])
def test_index_matrix(matrix_backend, rows_backend):
n_samples = 10
if matrix_backend == np.ndarray:
matrix = np.array(list(range(n_samples))).reshape((-1, 1))
elif matrix_backend == pd.DataFrame:
# We make sure that the index is not equal to the row number.
matrix = pd.DataFrame(
list(range(n_samples)), index=list(range(20, 20 + n_samples))
)
elif matrix_backend == csr_matrix:
matrix = csr_matrix(np.array(list(range(n_samples))).reshape((-1, 1)))
else:
raise ValueError()
rows = rows_backend([1, 4, 5])
result = index_matrix(matrix=matrix, rows=rows)

assert isinstance(result, matrix_backend)
assert result.shape[1] == matrix.shape[1]

if isinstance(result, pd.DataFrame):
processed_result = result.values[:, 0]
else:
processed_result = result[:, 0]

expected = np.array([1, 4, 5])
assert (processed_result == expected).sum() == len(expected)


@pytest.mark.parametrize("vector_backend", [np.ndarray, pd.Series])
@pytest.mark.parametrize("rows_backend", [np.array, pd.Series])
def test_index_vector(vector_backend, rows_backend):
n_samples = 10
if vector_backend == np.ndarray:
vector = np.array(list(range(n_samples)))
elif vector_backend == pd.Series:
# We make sure that the index is not equal to the row number.
vector = pd.Series(
list(range(n_samples)), index=list(range(20, 20 + n_samples))
)
else:
raise ValueError()

rows = rows_backend([1, 4, 5])

result = index_vector(vector=vector, rows=rows)
assert isinstance(result, vector_backend)

if isinstance(result, pd.Series):
result = result.values

expected = np.array([1, 4, 5])
assert (result == expected).all()

0 comments on commit 241d681

Please sign in to comment.