Skip to content

Commit

Permalink
Added some more plotting/prediction functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeesleyBIOSTAT authored May 6, 2024
1 parent bf63196 commit da86f83
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions impala/superCal/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,59 @@ def parameter_trace_plot(sample_parameters,ylim=None):



def parameter_trace_plot_rollmean(sample_parameters,ylim=None,num_draws = 100):
palette = plt.get_cmap('Set1')
if len(sample_parameters.shape) == 1:
n = sample_parameters.shape[0]
plt.plot(range(n), sample_parameters, marker = '', linewidth = 1)
plt.plot(range(n), pd.DataFrame(sample_parameters.reshape(-1,1)).rolling(num_draws).mean(), c = 'black', zorder = 2)
else:
# df = pd.DataFrame(sample_parameters, self.parameter_order)
n, d = sample_parameters.shape
for i in range(d):
plt.subplot(d, 1, i+1)
plt.plot(range(n), sample_parameters[:,i], marker = '',
color = palette(i), linewidth = 1)
plt.plot(range(n), pd.DataFrame(sample_parameters[:,i].reshape(-1,1)).rolling(num_draws).mean(), c = 'black', zorder = 2)
ax = plt.gca()
if ylim is not None:
ax.set_ylim(ylim)
# plt.show()
return



### Shows the number of swaps between pairs of temperatures.
### Want to see a substantial proportion of swaps
def total_temperature_swaps(out, setup):
if 'theta0' in dir(out):
ax = sns.heatmap(out.count_temper, linewidths =0)
ax.set_xlabel('Temperature Index')
ax.set_ylabel('Temperature Index')
ax.set_title('Number of Temperature Swaps')
total_swaps = out.count_temper.sum()/2 #symmetric matrix, number of accepted swaps
swaps_per_iter = total_swaps/(out.theta0.shape[0]-setup.start_temper)
attempts_per_iter = setup.nswap_per * setup.nswap
ax.text(out.count_temper.shape[0]*0.9,out.count_temper.shape[0]*0.1,
s = "Prop. Swaps Accepted: " + str(round(swaps_per_iter/attempts_per_iter,3)), color = 'white',
horizontalalignment = 'right')
else:
ax = sns.heatmap(out.count, linewidths =0)
ax.set_xlabel('Temperature Index')
ax.set_ylabel('Temperature Index')
ax.set_title('Number of Temperature Swaps')
total_swaps = out.count.sum()/2 #symmetric matrix, number of accepted swaps
swaps_per_iter = total_swaps/(out.theta.shape[0]-setup.start_temper)
attempts_per_iter = setup.nswap_per * setup.nswap
ax.text(out.count.shape[0]*0.9,out.count.shape[0]*0.1,
s = "Prop. Swaps Accepted: " + str(round(swaps_per_iter/attempts_per_iter,3)), color = 'white',
horizontalalignment = 'right')
return





def save_parent_strength(setup, ptw_mod, calib_out, mcmc_use, path):

theta_parent = sc.chol_sample_1per_constraints(
Expand Down Expand Up @@ -718,3 +771,21 @@ def get_best_sse(results_csv, write_path):

return


### Function for predicting the outcome for a given
### value of theta in its unnormalized (i.e., native) domain
def get_outcome_predictions_impala(setup, theta_input, disc_input = None):
outcome_draws = [np.empty([theta_input.shape[0], setup.ys[i].shape[0]]) for i in range(setup.nexp)]
CONSTRAINTS = setup.checkConstraints(sc.tran_unif(sc.normalize(theta_input,setup.bounds_mat),setup.bounds_mat, setup.bounds.keys()), setup.bounds)
if str(type(disc_input)) == "<class 'NoneType'>":
for i in range(setup.nexp):
outcome_draws[i] = setup.models[i].eval(sc.tran_unif(sc.normalize(theta_input,setup.bounds_mat),setup.bounds_mat, setup.bounds.keys()), pool = True) #even for hierarchical
else:
disc_dims = [setup.models[i].nd for i in range(setup.nexp)]
v_cur_list = [disc_input[:,0:disc_dims[i]] if i == 0 else disc_input[:,int(np.sum(disc_dims[0:(i-1)])):int(np.sum(disc_dims[0:i]))] for i in range(setup.nexp)]
for i in range(setup.nexp):
disc_y = (setup.models[i].D @ v_cur_list[i].T).T if setup.models[i].nd > 0 else np.repeat(0,len(setup.ys[i]))
outcome_draws[i] = setup.models[i].eval(sc.tran_unif(sc.normalize(theta_input,setup.bounds_mat),setup.bounds_mat, setup.bounds.keys()), pool = True) + disc_y#even for hierarchical
RESULTS = dict({'outcome_draws': outcome_draws, 'constraints_satisfied': CONSTRAINTS})
return RESULTS

0 comments on commit da86f83

Please sign in to comment.