Skip to content

Commit

Permalink
Appease mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Sep 5, 2024
1 parent 02529ae commit adc910c
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 18 deletions.
18 changes: 17 additions & 1 deletion metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
import polars as pl
import scipy
from sklearn.base import check_array, check_X_y, is_classifier, is_regressor
from sklearn.ensemble import (
Expand All @@ -32,8 +33,15 @@ def safe_len(X: Matrix) -> int:
return len(X)


def copy_matrix(matrix: Matrix) -> Matrix:
    """Make a copy of a matrix.

    polars ``DataFrame`` inputs are duplicated with ``clone()``; every other
    input is duplicated by calling its own ``copy()`` method.
    """
    if isinstance(matrix, pl.DataFrame):
        # NOTE: Codecov reported this line (metalearners/_utils.py#L39) as not
        # covered by tests when this change was committed.
        return matrix.clone()
    return matrix.copy()


def index_matrix(matrix: Matrix, rows: Vector) -> Matrix:
"""Subselect certain rows from a matrix."""
"""Subselect certain rows from a matrix."""
if isinstance(rows, pd.Series):
rows = rows.to_numpy()
if isinstance(matrix, pd.DataFrame):
Expand All @@ -60,6 +68,14 @@ def are_pd_indices_equal(*args: pd.DataFrame | pd.Series) -> bool:
return True


def to_np(data: Vector | Matrix) -> np.ndarray:
    """Convert a vector or matrix to a ``np.ndarray``.

    * ``np.ndarray`` inputs are returned as-is (no copy is made).
    * Objects exposing ``to_numpy`` are converted by calling that method.
    * Objects exposing ``toarray`` (the scipy sparse interface) are
      densified via that method.
    * Anything else is handed to ``np.array``.
    """
    if isinstance(data, np.ndarray):
        return data
    if hasattr(data, "to_numpy"):
        return data.to_numpy()
    if hasattr(data, "toarray"):
        # ``np.array`` on a scipy sparse matrix yields a 0-d object array that
        # merely wraps the sparse object, so densify it explicitly instead.
        return data.toarray()
    return np.array(data)

Check warning on line 76 in metalearners/_utils.py

View check run for this annotation

Codecov / codecov/patch

metalearners/_utils.py#L76

Added line #L76 was not covered by tests


def is_pd_df_or_series(arg) -> bool:
    """Return ``True`` if ``arg`` is a pandas ``DataFrame`` or ``Series``."""
    return isinstance(arg, (pd.DataFrame, pd.Series))

Expand Down
9 changes: 7 additions & 2 deletions metalearners/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import numpy as np
import pandas as pd
import polars as pl
from scipy.stats import wishart

from metalearners._typing import Matrix, Vector
from metalearners._utils import (
check_probability,
check_propensity_score,
convert_and_pad_propensity_score,
copy_matrix,
default_rng,
get_n_variants,
sigmoid,
Expand Down Expand Up @@ -239,8 +241,11 @@ def insert_missing(
check_probability(missing_probability, zero_included=True)
missing_mask = rng.binomial(1, p=missing_probability, size=X.shape).astype("bool")

masked = X.copy()
masked[missing_mask] = np.nan
masked = copy_matrix(X)
if isinstance(masked, pl.DataFrame):
raise ValueError()

Check warning on line 246 in metalearners/data_generation.py

View check run for this annotation

Codecov / codecov/patch

metalearners/data_generation.py#L246

Added line #L246 was not covered by tests
else:
masked[missing_mask] = np.nan
return masked


Expand Down
3 changes: 2 additions & 1 deletion metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
index_matrix,
infer_input_dict,
safe_len,
to_np,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
Expand Down Expand Up @@ -416,7 +417,7 @@ def _pseudo_outcome(
y0_estimate = y0_estimate[:, 0]
y1_estimate = y1_estimate[:, 0]

pseudo_outcome = (
pseudo_outcome = to_np(
(
(y - y1_estimate)
/ clip_element_absolute_value_to_epsilon(
Expand Down
2 changes: 1 addition & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ def __init__(
n_folds=n_folds,
random_state=random_state,
)
self._treatment_variants_mask: list[np.ndarray] | None = None
self._treatment_variants_mask: list[Vector] | None = None

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
Expand Down
10 changes: 7 additions & 3 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
get_predict,
get_predict_proba,
index_matrix,
index_vector,
infer_input_dict,
safe_len,
to_np,
validate_all_vectors_same_index,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
Expand Down Expand Up @@ -516,14 +518,16 @@ def _pseudo_outcome_and_weights(

y_residuals = y[mask] - y_estimates

w_binarized = w[mask] == treatment_variant
w_binarized = to_np(index_vector(w, mask) == treatment_variant)
w_residuals = w_binarized - w_estimates_binarized
w_residuals_padded = clip_element_absolute_value_to_epsilon(
w_residuals, epsilon
)

pseudo_outcomes = y_residuals / w_residuals_padded
weights = np.square(w_residuals)
pseudo_outcomes = to_np(y_residuals / w_residuals_padded)
# In principle np.square could also return a scalar.
# We ensure that the type is np.ndarray.
weights = to_np(np.square(w_residuals))

return pseudo_outcomes, weights

Expand Down
2 changes: 1 addition & 1 deletion metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_mask = []
self._treatment_variants_mask: list[Vector] = []

for v in range(self.n_variants):
self._treatment_variants_mask.append(w == v)
Expand Down
20 changes: 11 additions & 9 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
get_predict,
get_predict_proba,
index_matrix,
index_vector,
infer_input_dict,
infer_probabilities_output,
safe_len,
to_np,
validate_valid_treatment_variant_not_control,
warning_experimental_feature,
)
Expand Down Expand Up @@ -96,7 +98,7 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_mask = []
self._treatment_variants_mask: list[Vector] = []

qualified_fit_params = self._qualified_fit_params(fit_params)

Expand Down Expand Up @@ -421,12 +423,10 @@ def _pseudo_outcome(
treatment_indices = w == treatment_variant
control_indices = w == 0

treatment_outcome = index_matrix(
conditional_average_outcome_estimates, control_indices
)[:, treatment_variant]
control_outcome = index_matrix(
conditional_average_outcome_estimates, treatment_indices
)[:, 0]
treatment_outcome = conditional_average_outcome_estimates[
control_indices, treatment_variant
]
control_outcome = conditional_average_outcome_estimates[treatment_indices, 0]

if self.is_classification:
# Get the probability of positive class, multiclass is currently not supported.
Expand All @@ -436,8 +436,10 @@ def _pseudo_outcome(
control_outcome = control_outcome[:, 0]
treatment_outcome = treatment_outcome[:, 0]

imputed_te_treatment = y[treatment_indices] - control_outcome
imputed_te_control = treatment_outcome - y[control_indices]
imputed_te_treatment = (
to_np(index_vector(y, treatment_indices)) - control_outcome
)
imputed_te_control = treatment_outcome - to_np(index_vector(y, control_indices))

return imputed_te_control, imputed_te_treatment

Expand Down

0 comments on commit adc910c

Please sign in to comment.