Skip to content

Commit

Permalink
v2 of md driver
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed May 31, 2024
1 parent 74b230a commit caec0ab
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 253 deletions.
61 changes: 30 additions & 31 deletions examples/molecular_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@
from hippynn.experiment.serialization import load_checkpoint_from_cwd
from hippynn.tools import active_directory
from hippynn.molecular_dynamics.md import (
StaticVariable,
DynamicVariable,
Variable,
NullUpdater,
VelocityVerlet,
MolecularDynamics,
)

# Adjust size of system depending on device
if torch.cuda.is_available():
nrep = 25
device = "cuda"
device = torch.device("cuda")
else:
nrep = 10
device = "cpu"
device = torch.device("cpu")

# Load the pre-trained model
try:
Expand Down Expand Up @@ -79,54 +79,56 @@
# NOTE: Setting the initial acceleration is only necessary to exactly match the results
# in `ase_example.py.` In general, it can be set to zero without impacting the statistics
# of the trajectory.
coordinates=torch.tensor(np.array([atoms.get_positions()]), device=device)
cell=torch.tensor(np.array([atoms.get_cell()]), device=device)
species=torch.tensor(np.array([atoms.get_atomic_numbers()]), device=device)
coordinates = torch.as_tensor(np.array(atoms.get_positions()), device=device).unsqueeze_(0) # add batch axis
init_velocity = torch.as_tensor(np.array(atoms.get_velocities())).unsqueeze_(0)
cell = torch.as_tensor(np.array(atoms.get_cell()), device=device).unsqueeze_(0)
species = torch.as_tensor(np.array(atoms.get_atomic_numbers()), device=device).unsqueeze_(0)
mass = torch.as_tensor(atoms.get_masses()).unsqueeze_(0).unsqueeze_(-1) # add a batch axis and a feature axis
init_force = model(
coordinates=coordinates,
cell=cell,
species=species,
)["force"]
init_acceleration = init_force / atoms.get_masses().reshape(1,-1,1)
)["force"]
init_force = torch.as_tensor(init_force)
init_acceleration = init_force / mass

# Define a position "Variable"
position_variable = DynamicVariable(
# Define a position "Variable" and set updater to "VelocityVerlet"
position_variable = Variable(
name="position",
starting_values={
"position": atoms.get_positions(),
"velocity": atoms.get_velocities(),
data={
"position": coordinates,
"velocity": init_velocity,
"acceleration": init_acceleration,
"mass": atoms.get_masses(),
"mass": mass,
"cell": cell, # if added, PBC will be applied in each step of the VelocityVerlet updater
},
model_input_map={
"coordinates": "position",
},
device=device,
updater=VelocityVerlet(force_key="force"),
)

# Set an "Updater" for the position variable
position_updater = VelocityVerlet(force_key="force")
position_variable.set_updater(position_updater)

# Define species and cell Variables
species_variable = StaticVariable(
species_variable = Variable(
name="species",
values={"values": atoms.get_atomic_numbers()},
model_input_map={"species": "values"},
data={"species": species},
model_input_map={"species": "species"},
device=device,
updater=NullUpdater(),
)

cell_variable = StaticVariable(
cell_variable = Variable(
name="cell",
values={"values": np.array(atoms.get_cell())},
model_input_map={"cell": "values"},
data={"cell": cell},
model_input_map={"cell": "cell"},
device=device,
updater=NullUpdater(),
)

# Set up MD driver
emdee = MolecularDynamics(
dynamic_variables=[position_variable],
static_variables=[species_variable, cell_variable],
variables=[position_variable, species_variable, cell_variable],
model=model,
)

Expand Down Expand Up @@ -156,10 +158,7 @@ def print(self, diff_steps=None, data=None):
# epot = self.atoms.get_potential_energy() / len(self.atoms)
ekin = atoms.get_kinetic_energy() / self.n_atoms
# stress = self.atoms.get_stress()
print(
"Energy per atom: Ekin = %.7feV (T=%3.0fK)"
% (ekin, ekin / (1.5 * units.kB))
)
print("Energy per atom: Ekin = %.7feV (T=%3.0fK)" % (ekin, ekin / (1.5 * units.kB)))

# Run MD!
tracker = Tracker()
Expand Down
2 changes: 1 addition & 1 deletion hippynn/databases/h5_pyanitools.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False):


class PyAniFileDB(Database, PyAniMethods, Restartable):
def __init__(self, file, inputs, targets, *args, allow_unfound=False,species_key="species", quiet=False, **kwargs):
def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, **kwargs):

self.file = file
self.inputs = inputs
Expand Down
Loading

0 comments on commit caec0ab

Please sign in to comment.