Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zigzag indicator implementation using numba will close issue #443 and complete pr #693 #761

Merged
merged 3 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions pandas_ta/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,17 +1599,17 @@ def xsignals(self, signal=None, xa=None, xb=None, above=None, long=None, asbool=
trend_offset=trend_offset, trend_reset=trend_reset, offset=offset, **kwargs)
return self._post_process(result, **kwargs)

# def zigzag(self, close=None, pivot_leg=None, price_deviation=None, retrace=None, last_extreme=None, offset=None, **kwargs: DictLike):
# high = self._get_column(kwargs.pop("high", "high"))
# low = self._get_column(kwargs.pop("low", "low"))
# if close is not None:
# close = self._get_column(kwargs.pop("close", "close"))
# result = zigzag(
# high=high, low=low, close=close,
# pivot_leg=pivot_leg, price_deviation=price_deviation,
# retrace=retrace, last_extreme=last_extreme,
# offset=offset, **kwargs)
# return self._post_process(result, **kwargs)
def zigzag(self, close=None, pivot_leg=None, price_deviation=None, retrace=None, last_extreme=None, offset=None, **kwargs: DictLike):
high = self._get_column(kwargs.pop("high", "high"))
low = self._get_column(kwargs.pop("low", "low"))
if close is not None:
close = self._get_column(kwargs.pop("close", "close"))
result = zigzag(
high=high, low=low, close=close,
pivot_leg=pivot_leg, price_deviation=price_deviation,
retrace=retrace, last_extreme=last_extreme,
offset=offset, **kwargs)
return self._post_process(result, **kwargs)

# Volatility
def aberration(self, length=None, atr_length=None, offset=None, **kwargs: DictLike):
Expand Down
2 changes: 1 addition & 1 deletion pandas_ta/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"adx", "alphatrend", "amat", "aroon", "chop", "cksp", "decay",
"decreasing", "dpo", "increasing", "long_run", "psar", "qstick",
"rwi", "short_run", "trendflex", "tsignals", "ttm_trend", "vhf",
"vortex", "xsignals"
"vortex", "xsignals", "zigzag"
],
# Volatility
"volatility": [
Expand Down
4 changes: 3 additions & 1 deletion pandas_ta/trend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .vhf import vhf
from .vortex import vortex
from .xsignals import xsignals
from .zigzag import zigzag

__all__ = [
"adx",
Expand All @@ -42,5 +43,6 @@
"ttm_trend",
"vhf",
"vortex",
"xsignals"
"xsignals",
"zigzag",
]
150 changes: 139 additions & 11 deletions pandas_ta/trend/zigzag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# from numpy import isnan, nan, zeros
from numpy import isnan, nan, zeros, zeros_like, floor
from numba import njit
from pandas import Series
from pandas import Series, DataFrame
from pandas_ta._typing import DictLike, Int, IntFloat
from pandas_ta.utils import (
v_bool,
Expand All @@ -11,6 +11,113 @@
)


@njit
def np_rolling_hl(highs, lows, window_size):

num_extremum = 0
candles_len = len(highs)
rollings_idx = zeros(candles_len)
rollings_types = zeros(candles_len)
rollings_values = zeros(candles_len)

left_side, right_side = int(floor(window_size / 2)), int(floor(window_size / 2)) + 1
for i in range(left_side, candles_len - right_side): # sample_array = [*[left-window], *[center], *[right-window]]
lows_center = lows[i]
highs_center = highs[i]
lows_window = lows[i - left_side: i + right_side]
highs_window = highs[i - left_side: i + right_side]

if (lows_center <= lows_window).all():
rollings_idx[num_extremum] = i
rollings_types[num_extremum] = -1 # This -1 means it's a low swing
rollings_values[num_extremum] = lows_center
num_extremum += 1
if (highs_center >= highs_window).all():
rollings_idx[num_extremum] = i
rollings_types[num_extremum] = 1 # This 1 means it's a high swing
rollings_values[num_extremum] = highs_center
num_extremum += 1
return rollings_idx[:num_extremum], rollings_types[:num_extremum], rollings_values[:num_extremum]


@njit
def np_find_zigzags(rolling_idx, rolling_types, rolling_values, deviation):
rolling_len, num_zigzag = len(rolling_idx), 0

zigzag_idx = zeros_like(rolling_idx)
zigzag_types = zeros_like(rolling_types)
zigzag_values = zeros_like(rolling_values)
zigzag_dev = zeros(rolling_len)

zigzag_idx[num_zigzag] = rolling_idx[-1]
zigzag_types[num_zigzag] = rolling_types[-1]
zigzag_values[num_zigzag] = rolling_values[-1]
zigzag_dev[num_zigzag] = 0

for i in range(rolling_len - 2, -1, -1):
# last point in zigzag is bottom
if zigzag_types[num_zigzag] == -1:
if rolling_types[i] == -1:
if zigzag_values[num_zigzag] > rolling_values[i] and num_zigzag > 1:
current_deviation = (zigzag_values[num_zigzag - 1] - rolling_values[i]) / rolling_values[i]
zigzag_idx[num_zigzag] = rolling_idx[i]
zigzag_types[num_zigzag] = rolling_types[i]
zigzag_values[num_zigzag] = rolling_values[i]
zigzag_dev[num_zigzag - 1] = 100 * current_deviation
else:
current_deviation = (rolling_values[i] - zigzag_values[num_zigzag]) / rolling_values[i]
if current_deviation > deviation / 100:
if zigzag_idx[num_zigzag] == rolling_idx[i]:
continue
num_zigzag += 1
zigzag_idx[num_zigzag] = rolling_idx[i]
zigzag_types[num_zigzag] = rolling_types[i]
zigzag_values[num_zigzag] = rolling_values[i]
zigzag_dev[num_zigzag - 1] = 100 * current_deviation

# last point in zigzag is peak
else:
if rolling_types[i] == 1:
if zigzag_values[num_zigzag] < rolling_values[i] and num_zigzag > 1:
current_deviation = (rolling_values[i] - zigzag_values[num_zigzag - 1]) / rolling_values[i]
zigzag_idx[num_zigzag] = rolling_idx[i]
zigzag_types[num_zigzag] = rolling_types[i]
zigzag_values[num_zigzag] = rolling_values[i]
zigzag_dev[num_zigzag - 1] = 100 * current_deviation
else:
current_deviation = (zigzag_values[num_zigzag] - rolling_values[i]) / rolling_values[i]
if current_deviation > deviation / 100:
if zigzag_idx[num_zigzag] == rolling_idx[i]:
continue
num_zigzag += 1
zigzag_idx[num_zigzag] = rolling_idx[i]
zigzag_types[num_zigzag] = rolling_types[i]
zigzag_values[num_zigzag] = rolling_values[i]
zigzag_dev[num_zigzag - 1] = 100 * current_deviation

return zigzag_idx[:num_zigzag + 1], zigzag_types[:num_zigzag + 1], \
zigzag_values[:num_zigzag + 1], zigzag_dev[:num_zigzag + 1]


@njit
def map_zigzag(zigzag_idx, zigzag_types, zigzag_values, zigzag_dev, candles_num):
_values = zeros(candles_num)
_types = zeros(candles_num)
_dev = zeros(candles_num)

for i, index in enumerate(zigzag_idx):
_values[int(index)] = zigzag_values[i]
_types[int(index)] = zigzag_types[i]
_dev[int(index)] = zigzag_dev[i]

for i in range(candles_num):
if _types[i] == 0:
_values[i] = nan
_types[i] = nan
_dev[i] = nan
return _types, _values, _dev


def zigzag(
high: Series, low: Series, close: Series = None,
pivot_leg: int = None, price_deviation: IntFloat = None,
Expand Down Expand Up @@ -51,6 +158,7 @@ def zigzag(
pd.DataFrame: swing, and swing_type (high or low).
"""
# Validate
_length = 0
pivot_leg = _length = v_pos_default(pivot_leg, 10)
high = v_series(high, _length + 1)
low = v_series(low, _length + 1)
Expand All @@ -71,16 +179,36 @@ def zigzag(

# Calculation
np_high, np_low = high.values, low.values
highest_high = high.rolling(window=pivot_leg, center=True, min_periods=0).max()
lowest_low = low.rolling(window=pivot_leg, center=True, min_periods=0).min()

# Fix and fill working code
_rollings_idx, _rollings_types, _rollings_values = np_rolling_hl(highs=np_high, lows=np_low, window_size=pivot_leg)
_zigzags_idx, _zigzags_types, _zigzags_values, _zigzags_dev = np_find_zigzags(_rollings_idx, _rollings_types,
_rollings_values,
deviation=price_deviation)
_types, _values, _dev = map_zigzag(_zigzags_idx, _zigzags_types, _zigzags_values, _zigzags_dev, len(high))

# Offset
# if offset != 0:
if offset != 0:
_types = _types.shift(offset)
_values = _values.shift(offset)
_dev = _dev.shift(offset)

# Fill
# if "fillna" in kwargs:
# if "fill_method" in kwargs:

# Name and Category
if "fillna" in kwargs:
_types.fillna(kwargs["fillna"], inplace=True)
_values.fillna(kwargs["fillna"], inplace=True)
_dev.fillna(kwargs["fillna"], inplace=True)
if "fill_method" in kwargs:
_types.fillna(method=kwargs["fill_method"], inplace=True)
_values.fillna(method=kwargs["fill_method"], inplace=True)
_dev.fillna(method=kwargs["fill_method"], inplace=True)

_params = f"_{price_deviation}%_{pivot_leg}"
data = {
f"ZIGZAGt{_params}": _types,
f"ZIGZAGv{_params}": _values,
f"ZIGZAGd{_params}": _dev,
}
df = DataFrame(data, index=high.index)
df.name = f"ZIGZAG{_params}"
df.category = "trend"

return df