Skip to content

Commit

Permalink
Trying a simpler implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Sep 11, 2023
1 parent 48a2e58 commit b24f69f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 46 deletions.
86 changes: 52 additions & 34 deletions ipi/engine/forcefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,6 @@ def __init__(
baseline_uncertainty=-1.0,
active_thresh=0.0,
active_out=None,
comm_type="default",
):
# force threaded mode as otherwise it cannot have threaded children
super(FFCommittee, self).__init__(
Expand Down Expand Up @@ -1097,7 +1096,6 @@ def __init__(
self.alpha = alpha
self.active_thresh = active_thresh
self.active_out = active_out
self.comm_type = comm_type

def bind(self, output_maker):
super(FFCommittee, self).bind(output_maker)
Expand Down Expand Up @@ -1139,37 +1137,59 @@ def check_finish(self, r):
def gather(self, r):
"""Collects results from all sub-requests, and assemble the committee of models."""

if self.comm_type == "default":
r["result"] = [
0.0,
np.zeros(len(r["pos"]), float),
np.zeros((3, 3), float),
"",
]

# list of pointers to the forcefield requests. shallow copy so we can remove stuff
com_handles = r["ff_handles"].copy()
if self.baseline_name != "":
# looks for the baseline potential, store its value and drops it from the list
names = [ff.name for ff in self.fflist]

for i, ff_r in enumerate(com_handles):
if names[i] == self.baseline_name:
baseline_pot = ff_r["result"][0]
baseline_frc = ff_r["result"][1]
baseline_vir = ff_r["result"][2]
baseline_xtr = ff_r["result"][3]
com_handles.pop(i)
break
r["result"] = [
0.0,
np.zeros(len(r["pos"]), float),
np.zeros((3, 3), float),
"",
]

# Gathers the forcefield energetics and extras
pots = [ff_r["result"][0] for ff_r in com_handles]
frcs = [ff_r["result"][1] for ff_r in com_handles]
virs = [ff_r["result"][2] for ff_r in com_handles]
xtrs = [ff_r["result"][3] for ff_r in com_handles]
print("quants normal ", pots, frcs, virs)
# list of pointers to the forcefield requests. shallow copy so we can remove stuff
com_handles = r["ff_handles"].copy()
if self.baseline_name != "":
# looks for the baseline potential, store its value and drops it from the list
names = [ff.name for ff in self.fflist]

for i, ff_r in enumerate(com_handles):
if names[i] == self.baseline_name:
baseline_pot = ff_r["result"][0]
baseline_frc = ff_r["result"][1]
baseline_vir = ff_r["result"][2]
baseline_xtr = ff_r["result"][3]
com_handles.pop(i)
break

# Gathers the forcefield energetics and extras
pots = []
frcs = []
virs = []
xtrs = []

for ff_r in com_handles:
# if required, tries to extract multiple committe members from the extras JSON string
if "committee_pot" in ff_r["result"][3]:
pots += ff_r["result"][3]["committee_pot"]
if "committee_force" not in ff_r["result"][3]:
raise ValueError("JSON extras for committe potential misses `committee_force` entry")
frcs += ff_r["result"][3]["committee_force"]
if "committee_virial" not in ff_r["result"][3]:
raise ValueError("JSON extras for committe potential misses `committee_virial` entry")
virs += ff_r["result"][3]["committee_virial"]
ff_r["result"][3].pop("committee_pot")
ff_r["result"][3].pop("committee_force")
ff_r["result"][3].pop("committee_virial")
xtrs.append(ff_r["result"][3])
else:
pots.append(ff_r["result"][0])
frcs.append(ff_r["result"][1])
virs.append(ff_r["result"][2])
xtrs.append(ff_r["result"][3])
pots = np.array(pots)
frcs = np.array(frcs).reshape(len(pots), -1)
virs = np.array(virs).reshape(-1,3,3)

elif self.comm_type == "single-extras":
"""
if self.comm_type == "fudge single-extras":
# Check that we indeed have a single ff
if len(self.fflist) != 1 and self.baseline_name == "":
raise ValueError(
Expand Down Expand Up @@ -1219,9 +1239,7 @@ def gather(self, r):
# debug
# print("quants0 ", pots, frcs, virs, pots.shape, frcs.shape, virs.shape, pots.dtype, frcs.dtype, virs.dtype)
# MR: from now on everything else should be the same, hopefully

else:
raise OptionError("Committee option is unknown. Check possible options.")
"""

# Computes the mean energetics
mean_pot = np.mean(pots, axis=0)
Expand Down
20 changes: 8 additions & 12 deletions ipi/inputs/forcefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,16 @@ class InputFFCommittee(InputForceField):
of energy and forces. These are averaged, and the mean used as the
actual forcefield. Statistics about the distribution are also returned
as extras fields, and can be printed for further postprocessing.
It is also possible for a single FF object to return a JSON-formatted
string containing entries `committee_pot`, `committee_force` and
`committee_virial`, that contain multiple members at once. These
will be unpacked and combined with whatever else is present.
Also contains options to use it for uncertainty estimation and for
active learning in a ML context, based on a committee model.
Implements the approaches discussed in DOI: 10.1063/5.0036522.
Implements the approaches discussed in
[Musil et al.](http://doi.org/10.1021/acs.jctc.8b00959)
and [Imbalzano et al.](http://doi.org/10.1063/5.0036522)
"""
default_label = "FFCOMMITTEE"

Expand Down Expand Up @@ -772,15 +779,6 @@ class InputFFCommittee(InputForceField):
"help": "Output filename for structures that exceed the accuracy threshold of the model, to be used in active learning.",
},
)
fields["committee_type"] = (
InputValue,
{
"dtype": str,
"options": ["default", "single-extras"],
"default": "default",
"help": "Chooses the type of committee communication. Default is from several separate sockets. single-extras expects many values in the extras string from one socket.",
},
)

def store(self, ff):
"""Store all the sub-forcefields"""
Expand All @@ -795,7 +793,6 @@ def store(self, ff):
self.baseline_name.store(ff.baseline_name)
self.active_thresh.store(ff.active_thresh)
self.active_output.store(ff.active_out)
self.committee_type.store(ff.comm_type)

for _ii, _obj in enumerate(_fflist):
if self.extra[_ii] == 0:
Expand Down Expand Up @@ -852,7 +849,6 @@ def fetch(self):
baseline_name=self.baseline_name.fetch(),
active_thresh=self.active_thresh.fetch(),
active_out=self.active_output.fetch(),
comm_type=self.committee_type.fetch(),
)


Expand Down

0 comments on commit b24f69f

Please sign in to comment.