Skip to content

Commit

Permalink
Rename indices to mask. Closes issue #91 (#92)
Browse files Browse the repository at this point in the history
* rename indices to mask

* add CHANGELOG.rst entry

* rename additional variables in tlearner.py and update CHANGELOG.rst

* rename missed variables

* change indice to mask in drlearner.py and update CHANGELOG.rst

* remove changelog entry for 0.12.0 and files in docs/api
  • Loading branch information
kyracho authored Sep 4, 2024
1 parent 241d681 commit 93ec745
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 51 deletions.
Binary file added docs/examples/model.onnx
Binary file not shown.
10 changes: 4 additions & 6 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_indices = []
self._treatment_variants_mask = []

qualified_fit_params = self._qualified_fit_params(fit_params)

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
self._treatment_variants_mask.append(w == treatment_variant)

self._cv_split_indices: SplitIndices | None

Expand All @@ -168,10 +168,8 @@ def fit_all_nuisance(
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=y[self._treatment_variants_mask[treatment_variant]],
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
Expand Down
14 changes: 7 additions & 7 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,15 +1336,15 @@ def __init__(
n_folds=n_folds,
random_state=random_state,
)
self._treatment_variants_indices: list[np.ndarray] | None = None
self._treatment_variants_mask: list[np.ndarray] | None = None

def predict_conditional_average_outcomes(
self, X: Matrix, is_oos: bool, oos_method: OosMethod = OVERALL
) -> np.ndarray:
if self._treatment_variants_indices is None:
if self._treatment_variants_mask is None:
raise ValueError(
"The metalearner needs to be fitted before predicting."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"In particular, the MetaLearner's attribute _treatment_variant_mask, "
"typically set during fitting, is None."
)
# TODO: Consider multiprocessing
Expand All @@ -1363,17 +1363,17 @@ def predict_conditional_average_outcomes(
)
else:
conditional_average_outcomes_list[tv][
self._treatment_variants_indices[tv]
self._treatment_variants_mask[tv]
] = self.predict_nuisance(
X=index_matrix(X, self._treatment_variants_indices[tv]),
X=index_matrix(X, self._treatment_variants_mask[tv]),
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=tv,
is_oos=False,
)
conditional_average_outcomes_list[tv][
~self._treatment_variants_indices[tv]
~self._treatment_variants_mask[tv]
] = self.predict_nuisance(
X=index_matrix(X, ~self._treatment_variants_indices[tv]),
X=index_matrix(X, ~self._treatment_variants_mask[tv]),
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=tv,
is_oos=True,
Expand Down
10 changes: 4 additions & 6 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,19 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_indices = []
self._treatment_variants_mask = []

for v in range(self.n_variants):
self._treatment_variants_indices.append(w == v)
self._treatment_variants_mask.append(w == v)

qualified_fit_params = self._qualified_fit_params(fit_params)

nuisance_jobs: list[_ParallelJoblibSpecification | None] = []
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=y[self._treatment_variants_mask[treatment_variant]],
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
Expand Down
52 changes: 22 additions & 30 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_indices = []
self._treatment_variants_mask = []

qualified_fit_params = self._qualified_fit_params(fit_params)

self._cvs: list = []

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
self._treatment_variants_mask.append(w == treatment_variant)
if synchronize_cross_fitting:
cv_split_indices = self._split(
index_matrix(X, self._treatment_variants_indices[treatment_variant])
index_matrix(X, self._treatment_variants_mask[treatment_variant])
)
else:
cv_split_indices = None
Expand All @@ -116,10 +116,8 @@ def fit_all_nuisance(
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=y[self._treatment_variants_mask[treatment_variant]],
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
Expand Down Expand Up @@ -159,10 +157,10 @@ def fit_all_treatment(
synchronize_cross_fitting: bool = True,
n_jobs_base_learners: int | None = None,
) -> Self:
if self._treatment_variants_indices is None:
if self._treatment_variants_mask is None:
raise ValueError(
"The nuisance models need to be fitted before fitting the treatment models."
"In particular, the MetaLearner's attribute _treatment_variant_indices, "
"In particular, the MetaLearner's attribute _treatment_variant_mask, "
"typically set during nuisance fitting, is None."
)
if not hasattr(self, "_cvs"):
Expand All @@ -188,9 +186,7 @@ def fit_all_treatment(
)
treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=imputed_te_treatment,
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
Expand All @@ -202,7 +198,7 @@ def fit_all_treatment(

treatment_jobs.append(
self._treatment_joblib_specifications(
X=index_matrix(X, self._treatment_variants_indices[0]),
X=index_matrix(X, self._treatment_variants_mask[0]),
y=imputed_te_control,
model_kind=CONTROL_EFFECT_MODEL,
model_ord=treatment_variant - 1,
Expand All @@ -225,10 +221,10 @@ def predict(
is_oos: bool,
oos_method: OosMethod = OVERALL,
) -> np.ndarray:
if self._treatment_variants_indices is None:
if self._treatment_variants_mask is None:
raise ValueError(
"The MetaLearner needs to be fitted before predicting. "
"In particular, the X-Learner's attribute _treatment_variant_indices, "
"In particular, the X-Learner's attribute _treatment_variant_mask, "
"typically set during fitting, is None."
)
n_outputs = 2 if self.is_classification else 1
Expand All @@ -243,14 +239,12 @@ def predict(
oos_method=propensity_score_oos,
)

control_indices = self._treatment_variants_indices[0]
control_indices = self._treatment_variants_mask[0]
non_control_indices = ~control_indices

for treatment_variant in range(1, self.n_variants):
treatment_variant_indices = self._treatment_variants_indices[
treatment_variant
]
non_treatment_variant_indices = ~treatment_variant_indices
treatment_variant_mask = self._treatment_variants_mask[treatment_variant]
non_treatment_variant_mask = ~treatment_variant_mask
if is_oos:
tau_hat_treatment = self.predict_treatment(
X=X,
Expand All @@ -270,18 +264,16 @@ def predict(
tau_hat_treatment = np.zeros(safe_len(X))
tau_hat_control = np.zeros(safe_len(X))

tau_hat_treatment[non_treatment_variant_indices] = (
self.predict_treatment(
X=index_matrix(X, non_treatment_variant_indices),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=True,
oos_method=oos_method,
)
tau_hat_treatment[non_treatment_variant_mask] = self.predict_treatment(
X=index_matrix(X, non_treatment_variant_mask),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=True,
oos_method=oos_method,
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
tau_hat_treatment[treatment_variant_mask] = self.predict_treatment(
X=index_matrix(X, treatment_variant_mask),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,10 +880,10 @@ def test_model_reusage(outcome_kind, request):
VARIANT_OUTCOME_MODEL: tlearner._nuisance_models[VARIANT_OUTCOME_MODEL]
},
)
# We need to manually copy _treatment_variants_indices for the xlearner as it's needed
# We need to manually copy _treatment_variants_mask for the xlearner as it's needed
# for predict, the user should not have to do this as they should call fit before predict.
# This is just for testing.
xlearner._treatment_variants_indices = tlearner._treatment_variants_indices
xlearner._treatment_variants_mask = tlearner._treatment_variants_mask
np.testing.assert_allclose(
tlearner.predict_conditional_average_outcomes(covariates, False),
xlearner.predict_conditional_average_outcomes(covariates, False),
Expand Down

0 comments on commit 93ec745

Please sign in to comment.