Skip to content

Commit

Permalink
Merge pull request #2 from zerothi/aniSileSiesta-changes-1
Browse files Browse the repository at this point in the history
mnt: fixed and simplified some logic in the read_multiple function
  • Loading branch information
tfrederiksen authored Mar 8, 2023
2 parents d7fd5af + e577063 commit a7f81b3
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 50 deletions.
8 changes: 6 additions & 2 deletions sisl/io/siesta/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from ..sile import add_sile
from sisl._internal import set_module


__all__ = ["aniSileSiesta"]


@set_module("sisl.io.siesta")
class aniSileSiesta(xyzSile):

pass
def read_geometry(*args, all=True, **kwargs):
return super().read_geometry(*args, all=all, **kwargs)


add_sile('ani', aniSileSiesta, case=False, gzip=True)
add_sile('ANI', aniSileSiesta, gzip=True)
1 change: 1 addition & 0 deletions sisl/io/siesta/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sisl import Geometry, Atom, AtomGhost, AtomUnknown, Atoms, SuperCell
from sisl.unit.siesta import unit_convert


__all__ = ['structSileSiesta']


Expand Down
117 changes: 70 additions & 47 deletions sisl/io/sile.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ def pre_open(self, *args, **kwargs):


def sile_read_multiple(start=0, stop=1000000000, step=None, all=False,
skip_call=None, pre_call=None, post_call=None):
skip_call=None, pre_call=None, post_call=None,
postprocess=None):
""" Method decorator for doing multiple reads
Parameters
Expand All @@ -562,68 +563,90 @@ def sile_read_multiple(start=0, stop=1000000000, step=None, all=False,
all : bool, optional
read all entries (takes precedence over start/step)
skip_call : function, optional
read method without actual data processing
read method without actual data processing.
The skip call must something different from ``None`` if it succeeds.
pre_call : function, optional
function to be applied before each read call
function to be applied before each read call.
Does *not* apply to the `skip_call`.
post_call : function, optional
function to be applied after each read call
Does *not* apply to the `skip_call`.
postprocess : function, optional
function to be applied on the returned data before it gets returned.
Returns
-------
single entry or list from multiple reads
"""
def decorator(func):
# ensure we can access the callables
nonlocal skip_call, pre_call, post_call, postprocess

if pre_call is None:
def pre_call(): pass

if post_call is None:
def post_call(): pass

if postprocess is None:
def postprocess(x): return x

def wrap_func(*args, **kwargs):
pre_call()
r = func(*args, **kwargs)
post_call()
return r

# if skip_call is not set, we'll simply use func.
# Only in this case will pre and post be called
# as well.
# they should be equivalent but skip can be used to
# skip some processing of the read data.
if skip_call is None:
skip_call = wrap_func

@wraps(func)
def multiple(self, *args, start=start, stop=stop, step=step, all=all,
skip_call=skip_call, pre_call=pre_call, post_call=post_call, **kwargs):

if skip_call is None:
skip_call = func
if pre_call is None:
def pre_call(): pass
if post_call is None:
def post_call(): pass

def wrap_func(self, *args, **kwargs):
pre_call()
r = func(self, *args, **kwargs)
post_call()
return r

def wrap_skip_call(self, *args, **kwargs):
pre_call()
r = skip_call(self, *args, **kwargs)
post_call()
return r
def multiple(*args, start=start, stop=stop, step=step, all=all, **kwargs):
self = args[0]

if all:
start, stop, step = 0, 1000000000, 1
start, stop, step = 0, 100000000000, 1
if stop < 0:
# TODO this will probably fail if users do stop=-4 (to not have the last 3 items)
# It should be done in a different manner
stop = 100000000000

# start by reading past start
for _ in range(start):
r = skip_call(*args, **kwargs)

if step is None:
for _ in range(start):
wrap_skip_call(self, *args, **kwargs)
if hasattr(self, "fh"):
return wrap_func(self, *args, **kwargs)
with self:
return wrap_func(self, *args, **kwargs)

def loop(self):
rng = range(start, stop, step)
R, r, ir = [], True, 0
while r is not None:
if ir in rng:
r = wrap_func(self, *args, **kwargs)
if r is not None:
R.append(r)
else:
r = wrap_skip_call(self, *args, **kwargs)
ir += 1
return R
# read a single element and return
def ret_func():
return wrap_func(*args, **kwargs)

else:
# read a set of geometries and return
def ret_func():
R = []
for ir in range(start, stop, step):
r = wrap_func(*args, **kwargs)
if r is None:
return R
R.append(r)

# skip the middle steps
for _ in range(step-1):
r = skip_call(*args, **kwargs)
if r is None:
return R
return R

if hasattr(self, "fh"):
return loop(self)
with self:
return loop(self)
return postprocess(ret_func())
else:
with self:
return postprocess(ret_func())

return multiple
return decorator
Expand Down
2 changes: 2 additions & 0 deletions sisl/io/tests/test_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,5 @@ def test_xyz_multiple(sisl_tmp):
assert g[0].na == 1 and g[-1].na == 3
g = xyzSile(f).read_geometry(stop=2, step=1)
assert g[0].na == 1 and g[-1].na == 2

g = xyzSile(f).read_geometry(sc=None, atoms=None)
16 changes: 15 additions & 1 deletion sisl/io/xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,21 @@ def _r_geometry(self, na, sp, xyz, sc):
sc = SuperCell(cell, nsc=[1] * 3)
return Geometry(xyz, atoms=sp, sc=sc)

def _r_geometry_skip(self, *args, **kwargs):
""" Read the geometry for a generic xyz file (not sisl, nor ASE) """
# The cell dimensions isn't defined, we are going to create a molecule box
line = self.readline()
if line == '':
return None

na = int(line)
self.readline()
for _ in range(na):
self.readline()
return na

@sile_fh_open()
@sile_read_multiple()
@sile_read_multiple(skip_call=_r_geometry_skip)
def read_geometry(self, atoms=None, sc=None):
""" Returns Geometry object from the XYZ file
Expand All @@ -102,6 +115,7 @@ def read_geometry(self, atoms=None, sc=None):
line = self.readline()
if line == '':
return None

# Read number of atoms
na = int(line)

Expand Down

0 comments on commit a7f81b3

Please sign in to comment.