Skip to content

Commit

Permalink
Flak8 refactoring. Fix minor bugs and close tags
Browse files Browse the repository at this point in the history
  • Loading branch information
roaffix committed Nov 14, 2020
1 parent 28a0f61 commit c20136f
Show file tree
Hide file tree
Showing 18 changed files with 496 additions and 600 deletions.
47 changes: 30 additions & 17 deletions arrayfire/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#######################################################
# Copyright (c) 2019, ArrayFire
# Copyright (c) 2020, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
Expand All @@ -14,11 +14,13 @@
from .array import Array
from .library import backend, safe_call, BINARYOP, c_bool_t, c_double_t, c_int_t, c_pointer, c_uint_t


def _parallel_dim(a, dim, c_func):
out = Array()
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
return out


def _reduce_all(a, c_func):
real = c_double_t(0)
imag = c_double_t(0)
Expand All @@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
imag = imag.value
return real if imag == 0 else real + imag * 1j


def _nan_parallel_dim(a, dim, c_func, nan_val):
out = Array()
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
return out


def _nan_reduce_all(a, c_func, nan_val):
real = c_double_t(0)
imag = c_double_t(0)
Expand All @@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
imag = imag.value
return real if imag == 0 else real + imag * 1j


def _FNSD(dim, dims):
if dim >= 0:
return int(dim)
Expand All @@ -55,20 +60,26 @@ def _FNSD(dim, dims):
break
return int(fnsd)


def _rbk_dim(keys, vals, dim, c_func):
keys_out = Array()
vals_out = Array()
rdim = _FNSD(dim, vals.dims())
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
return keys_out, vals_out


def _nan_rbk_dim(a, dim, c_func, nan_val):
keys_out = Array()
vals_out = Array()
# FIXME: vals is undefined
rdim = _FNSD(dim, vals.dims())
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
# FIXME: keys is undefined
safe_call(c_func(
c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
return keys_out, vals_out


def sum(a, dim=None, nan_val=None):
"""
Calculate the sum of all the elements along a specified dimension.
Expand All @@ -88,18 +99,16 @@ def sum(a, dim=None, nan_val=None):
The sum of all elements in `a` along dimension `dim`.
If `dim` is `None`, sum of the entire Array is returned.
"""
if nan_val is not None:
if dim is not None:
if nan_val:
if dim:
return _nan_parallel_dim(a, dim, backend.get().af_sum_nan, nan_val)
return _nan_reduce_all(a, backend.get().af_sum_nan_all, nan_val)

if dim is not None:
if dim:
return _parallel_dim(a, dim, backend.get().af_sum)
return _reduce_all(a, backend.get().af_sum_all)




def sumByKey(keys, vals, dim=-1, nan_val=None):
"""
Calculate the sum of elements along a specified dimension according to a key.
Expand All @@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
values: af.Array or scalar number
The sum of all elements in `vals` along dimension `dim` according to keys
"""
if (nan_val is not None):
if nan_val:
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
else:
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)


def product(a, dim=None, nan_val=None):
"""
Expand Down Expand Up @@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
values: af.Array or scalar number
The product of all elements in `vals` along dimension `dim` according to keys
"""
if (nan_val is not None):
if nan_val is not None:
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
else:
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)


def min(a, dim=None):
"""
Expand Down Expand Up @@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)


def max(a, dim=None):
"""
Find the maximum value of all the elements along a specified dimension.
Expand Down Expand Up @@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)


def all_true(a, dim=None):
"""
Check if all the elements along a specified dimension are true.
Expand Down Expand Up @@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)


def any_true(a, dim=None):
"""
Check if any the elements along a specified dimension are true.
Expand All @@ -334,8 +346,8 @@ def any_true(a, dim=None):
"""
if dim is not None:
return _parallel_dim(a, dim, backend.get().af_any_true)
else:
return _reduce_all(a, backend.get().af_any_true_all)
return _reduce_all(a, backend.get().af_any_true_all)


def anyTrueByKey(keys, vals, dim=-1):
"""
Expand All @@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)


def count(a, dim=None):
"""
Count the number of non zero elements in an array along a specified dimension.
Expand All @@ -378,8 +391,7 @@ def count(a, dim=None):
"""
if dim is not None:
return _parallel_dim(a, dim, backend.get().af_count)
else:
return _reduce_all(a, backend.get().af_count_all)
return _reduce_all(a, backend.get().af_count_all)


def countByKey(keys, vals, dim=-1):
Expand All @@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)


def imin(a, dim=None):
"""
Find the value and location of the minimum value along a specified dimension
Expand Down
14 changes: 4 additions & 10 deletions arrayfire/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def cast(a, dtype):
out : af.Array
array containing the values from `a` after converting to `dtype`.
"""
out=Array()
out = Array()
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
return out

Expand Down Expand Up @@ -156,15 +156,8 @@ def clamp(val, low, high):
vdims = dim4_to_tuple(val.dims())
vty = val.type()

if not is_low_array:
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
else:
low_arr = low.arr

if not is_high_array:
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
else:
high_arr = high.arr
low_arr = low.arr if is_low_array else constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
high_arr = high.arr if is_high_array else constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)

safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))

Expand Down Expand Up @@ -1003,6 +996,7 @@ def sqrt(a):
"""
return _arith_unary_func(a, backend.get().af_sqrt)


def rsqrt(a):
"""
Reciprocal or inverse square root of each element in the array.
Expand Down
Loading

0 comments on commit c20136f

Please sign in to comment.