From b24f69fb0b8b3a1ca5c0c9d9936c767dd3121c2a Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Sun, 10 Sep 2023 23:02:18 -0700 Subject: [PATCH] Trying a simpler implementation --- ipi/engine/forcefields.py | 86 +++++++++++++++++++++++---------------- ipi/inputs/forcefields.py | 20 ++++----- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/ipi/engine/forcefields.py b/ipi/engine/forcefields.py index 8e53817eb..4038a26a6 100644 --- a/ipi/engine/forcefields.py +++ b/ipi/engine/forcefields.py @@ -1062,7 +1062,6 @@ def __init__( baseline_uncertainty=-1.0, active_thresh=0.0, active_out=None, - comm_type="default", ): # force threaded mode as otherwise it cannot have threaded children super(FFCommittee, self).__init__( @@ -1097,7 +1096,6 @@ def __init__( self.alpha = alpha self.active_thresh = active_thresh self.active_out = active_out - self.comm_type = comm_type def bind(self, output_maker): super(FFCommittee, self).bind(output_maker) @@ -1139,37 +1137,59 @@ def check_finish(self, r): def gather(self, r): """Collects results from all sub-requests, and assemble the committee of models.""" - if self.comm_type == "default": - r["result"] = [ - 0.0, - np.zeros(len(r["pos"]), float), - np.zeros((3, 3), float), - "", - ] - - # list of pointers to the forcefield requests. shallow copy so we can remove stuff - com_handles = r["ff_handles"].copy() - if self.baseline_name != "": - # looks for the baseline potential, store its value and drops it from the list - names = [ff.name for ff in self.fflist] - - for i, ff_r in enumerate(com_handles): - if names[i] == self.baseline_name: - baseline_pot = ff_r["result"][0] - baseline_frc = ff_r["result"][1] - baseline_vir = ff_r["result"][2] - baseline_xtr = ff_r["result"][3] - com_handles.pop(i) - break + r["result"] = [ + 0.0, + np.zeros(len(r["pos"]), float), + np.zeros((3, 3), float), + "", + ] - # Gathers the forcefield energetics and extras - pots = [ff_r["result"][0] for ff_r in com_handles] - frcs = [ff_r["result"][1] for ff_r in com_handles] - virs = [ff_r["result"][2] for ff_r in com_handles] - xtrs = [ff_r["result"][3] for ff_r in com_handles] - print("quants normal ", pots, frcs, virs) + # list of pointers to the forcefield requests. shallow copy so we can remove stuff + com_handles = r["ff_handles"].copy() + if self.baseline_name != "": + # looks for the baseline potential, store its value and drops it from the list + names = [ff.name for ff in self.fflist] + + for i, ff_r in enumerate(com_handles): + if names[i] == self.baseline_name: + baseline_pot = ff_r["result"][0] + baseline_frc = ff_r["result"][1] + baseline_vir = ff_r["result"][2] + baseline_xtr = ff_r["result"][3] + com_handles.pop(i) + break + + # Gathers the forcefield energetics and extras + pots = [] + frcs = [] + virs = [] + xtrs = [] + + for ff_r in com_handles: + # if required, tries to extract multiple committe members from the extras JSON string + if "committee_pot" in ff_r["result"][3]: + pots += ff_r["result"][3]["committee_pot"] + if "committee_force" not in ff_r["result"][3]: + raise ValueError("JSON extras for committe potential misses `committee_force` entry") + frcs += ff_r["result"][3]["committee_force"] + if "committee_virial" not in ff_r["result"][3]: + raise ValueError("JSON extras for committe potential misses `committee_virial` entry") + virs += ff_r["result"][3]["committee_virial"] + ff_r["result"][3].pop("committee_pot") + ff_r["result"][3].pop("committee_force") + ff_r["result"][3].pop("committee_virial") + xtrs.append(ff_r["result"][3]) + else: + pots.append(ff_r["result"][0]) + frcs.append(ff_r["result"][1]) + virs.append(ff_r["result"][2]) + xtrs.append(ff_r["result"][3]) + pots = np.array(pots) + frcs = np.array(frcs).reshape(len(pots), -1) + virs = np.array(virs).reshape(-1,3,3) - elif self.comm_type == "single-extras": + """ + if self.comm_type == "fudge single-extras": # Check that we indeed have a single ff if len(self.fflist) != 1 and self.baseline_name == "": raise ValueError( @@ -1219,9 +1239,7 @@ def gather(self, r): # debug # print("quants0 ", pots, frcs, virs, pots.shape, frcs.shape, virs.shape, pots.dtype, frcs.dtype, virs.dtype) # MR: from now on everything else should be the same, hopefully - - else: - raise OptionError("Committee option is unknown. Check possible options.") + """ # Computes the mean energetics mean_pot = np.mean(pots, axis=0) diff --git a/ipi/inputs/forcefields.py b/ipi/inputs/forcefields.py index b19dfde29..b77512369 100644 --- a/ipi/inputs/forcefields.py +++ b/ipi/inputs/forcefields.py @@ -701,9 +701,16 @@ class InputFFCommittee(InputForceField): of energy and forces. These are averaged, and the mean used as the actual forcefield. Statistics about the distribution are also returned as extras fields, and can be printed for further postprocessing. + It is also possible for a single FF object to return a JSON-formatted + string containing entries `committee_pot`, `committee_force` and + `committee_virial`, that contain multiple members at once. These + will be unpacked and combined with whatever else is present. + Also contains options to use it for uncertainty estimation and for active learning in a ML context, based on a committee model. - Implements the approaches discussed in DOI: 10.1063/5.0036522. + Implements the approaches discussed in + [Musil et al.](http://doi.org/10.1021/acs.jctc.8b00959) + and [Imbalzano et al.](http://doi.org/10.1063/5.0036522) """ default_label = "FFCOMMITTEE" @@ -772,15 +779,6 @@ class InputFFCommittee(InputForceField): "help": "Output filename for structures that exceed the accuracy threshold of the model, to be used in active learning.", }, ) - fields["committee_type"] = ( - InputValue, - { - "dtype": str, - "options": ["default", "single-extras"], - "default": "default", - "help": "Chooses the type of committee communication. Default is from several separate sockets. single-extras expects many values in the extras string from one socket.", - }, - ) def store(self, ff): """Store all the sub-forcefields""" @@ -795,7 +793,6 @@ def store(self, ff): self.baseline_name.store(ff.baseline_name) self.active_thresh.store(ff.active_thresh) self.active_output.store(ff.active_out) - self.committee_type.store(ff.comm_type) for _ii, _obj in enumerate(_fflist): if self.extra[_ii] == 0: @@ -852,7 +849,6 @@ def fetch(self): baseline_name=self.baseline_name.fetch(), active_thresh=self.active_thresh.fetch(), active_out=self.active_output.fetch(), - comm_type=self.committee_type.fetch(), )