Skip to content

Commit

Permalink
working md example and fully documented md driver code
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed May 8, 2024
1 parent cf85144 commit 639b664
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 117 deletions.
14 changes: 14 additions & 0 deletions examples/InPSNAPExample.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
"""
Example training to the SNAP database for Indium Phosphide.
This script was designed for an external dataset available at
https://github.com/FitSNAP/FitSNAP
For info on the dataset, see the following publication:
Explicit Multielement Extension of the Spectral Neighbor Analysis Potential for Chemically Complex Systems
M. A. Cusentino, M. A. Wood, and A. P. Thompson
The Journal of Physical Chemistry A 2020 124 (26), 5456-5464
DOI: 10.1021/acs.jpca.0c02450
"""

import numpy as np
import torch
torch.set_default_dtype(torch.float32)
Expand Down
2 changes: 1 addition & 1 deletion examples/ani1x_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import ase.units

import sys
sys.path.append("../../datasets/ani-al/readers/lib/")
sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py

import pyanitools

Expand Down
2 changes: 1 addition & 1 deletion examples/ani_aluminum_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import sys

sys.path.append("../../datasets/ani-al/readers/lib/")
sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
import pyanitools # Check if pyanitools is found early

import torch
Expand Down
2 changes: 1 addition & 1 deletion examples/ani_aluminum_example_multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import sys

sys.path.append("../../datasets/ani-al/readers/lib/")
sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
import pyanitools # Check if pyanitools is found early

import torch
Expand Down
14 changes: 8 additions & 6 deletions examples/ase_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Script running an aluminum model with ASE.
For training see `ani_aluminum_example.py`.
This will generate the files for a model.
Before running this script, you must run
`ani_aluminum_example.py` to train the corresponding
model.
Modified from ase MD example.
Expand All @@ -12,7 +14,6 @@
# Imports
import numpy as np
import torch
import hippynn
import ase
import time

Expand Down Expand Up @@ -45,15 +46,16 @@
nrep = 10 # 1,000 atoms.

# Build the atoms object
atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05)
atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05, orthorhombic=True)
reps = nrep * np.eye(3, dtype=int)
atoms = ase.build.make_supercell(atoms, reps, wrap=True)
atoms.calc = calc

print("Number of atoms:", len(atoms))

atoms.rattle(0.1)
MaxwellBoltzmannDistribution(atoms, temperature_K=500)
rng = np.random.default_rng(seed=0)
atoms.rattle(0.1, rng=rng)
MaxwellBoltzmannDistribution(atoms, temperature_K=500, rng=rng)
dyn = VelocityVerlet(atoms, 1 * units.fs)


Expand Down
20 changes: 11 additions & 9 deletions examples/ase_example_multilayer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""
Script running an aluminum model with ASE.
For training see `ani_aluminum_example.py`.
This will generate the files for a model.
Modified from ase MD example.
This script is designed to match the
LAMMPS script located at
./lammps/in.mliap.unified.hippynn.Al
Before running this script, you must run
`ani_aluminum_example_multilayer.py` to
train the corresponding model.
If a GPU is available, this script
will use it, and run a somewhat bigger system.
Modified from ase MD example.
"""

# Imports
Expand All @@ -29,7 +32,7 @@
with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False)
except FileNotFoundError:
raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")
raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

model = bundle["training_modules"].model

Expand All @@ -38,15 +41,14 @@
energy_node = model.node_from_name("energy")
calc = HippynnCalculator(energy_node, en_unit=units.eV)
calc.to(torch.float64)

if torch.cuda.is_available():
nrep = 4
calc.to(torch.device('cuda'))
else:
nrep = 10

# Build the atoms object
atoms = FaceCenteredCubic(directions=np.eye(3, dtype=int),
size=(1,1,1), symbol='Al', pbc=(True,True,True))
nrep = 4
reps = nrep*np.eye(3, dtype=int)
atoms = ase.build.make_supercell(atoms, reps, wrap=True)
atoms.calc = calc
Expand Down
2 changes: 1 addition & 1 deletion examples/close_contact_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
import sys

sys.path.append("../../datasets/ani-al/readers/lib/")
sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
import pyanitools # Check if pyanitools is found early

### Loading the database
Expand Down
2 changes: 1 addition & 1 deletion examples/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

import sys

sys.path.append("../../datasets/ani-al/readers/lib/")
sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
import pyanitools # Check if pyanitools is found early
from hippynn.databases.h5_pyanitools import PyAniDirectoryDB

Expand Down
24 changes: 12 additions & 12 deletions examples/molecular_dynamics_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)

if torch.cuda.is_available():
nrep = 10
nrep = 30
device = "cuda"
else:
nrep = 10
Expand All @@ -52,23 +52,23 @@
energy_node = model.node_from_name("energy")
force_node = physics.GradientNode("force", (energy_node, positions_node), sign=-1)

# # Replace pair-finder with more efficient one (the HippynnCalculator also does this)
# old_pairs_node = model.node_from_name("PairIndexer")
# species_node = model.node_from_name("species")
# cell_node = model.node_from_name("cell")
# model.print_structure()
# # PositionsNode, Encoder, PaddingIndexer, CellNode
# new_pairs_node = KDTreePairsMemory("PairIndexer", parents=(positions_node, species_node, cell_node), skin=2, dist_hard_max=7.5)
# hippynn_node = model.node_from_name("HIPNN")
# print(hippynn_node.parents)
# replace_node(old_pairs_node, new_pairs_node)
# Replace pair-finder with more efficient one (the HippynnCalculator also does this)
old_pairs_node = model.node_from_name("PairIndexer")
species_node = model.node_from_name("species")
cell_node = model.node_from_name("cell")
model.print_structure()
# PositionsNode, Encoder, PaddingIndexer, CellNode
new_pairs_node = KDTreePairsMemory("PairIndexer", parents=(positions_node, species_node, cell_node), skin=2, dist_hard_max=7.5)
hippynn_node = model.node_from_name("HIPNN")
print(hippynn_node.parents)
replace_node(old_pairs_node, new_pairs_node)

model = Predictor(inputs=model.input_nodes, outputs=[force_node])
model.to(device)
model.to(torch.float64)

# Use ASE to generate initial positions and velocities
atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05)
atoms = ase.build.bulk("Al", crystalstructure="fcc", a=4.05, orthorhombic=True)
reps = nrep * np.eye(3, dtype=int)
atoms = ase.build.make_supercell(atoms, reps, wrap=True)

Expand Down
2 changes: 2 additions & 0 deletions examples/singlet_triplet_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# NOTE: This script needs revision before it will run

import torch

# Setup pytorch things
Expand Down
2 changes: 1 addition & 1 deletion hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def neighbor_list_torch(cutoff: float, coords, cell):
def neighbor_list_kdtree(cutoff, coords, cell):
'''
Use KD Tree implementation from scipy.spatial to find pairs under periodic boundary conditions
with an orthonormal cell.
with an orthorhombic cell.
'''

# Verify that cell is orthorhombic
Expand Down
Loading

0 comments on commit 639b664

Please sign in to comment.