Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsxt: Typing fixes for Python 3.6 compatibility. #353

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions capirca/lib/nsxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

import datetime
import json
from typing import Literal, TypedDict, Optional, Union, Tuple
from typing import Optional, Union, Tuple, List

from absl import logging
from capirca.lib import aclgenerator
from capirca.lib import nacaddr
from capirca.lib import policy # for typing information
from typing_extensions import Literal, TypedDict # pylint: disable=g-multiple-import

_ACTION_TABLE = {
'accept': 'ALLOW',
Expand Down Expand Up @@ -90,9 +91,9 @@ class NsxtUnsupportedManyPoliciesError(Error):
class ServiceEntries:
"""Represents service entries for a rule."""

def __init__(self, protocol: int, source_ports: list[Tuple[str, str]],
destination_ports: list[Tuple[str, str]],
icmp_types: list[int]):
def __init__(self, protocol: int, source_ports: List[Tuple[str, str]],
destination_ports: List[Tuple[str, str]],
icmp_types: List[int]):
"""Setting things up.

Args:
Expand Down Expand Up @@ -239,8 +240,8 @@ def __str__(self):
af_list = [self.af]

# There can be many source and destination addresses.
source_address: list[nacaddr.IPType] = []
destination_address: list[nacaddr.IPType] = []
source_address: List[nacaddr.IPType] = []
destination_address: List[nacaddr.IPType] = []
source_addr = []
destination_addr = []

Expand All @@ -257,32 +258,32 @@ def __str__(self):
# cannot be a part of a netblock passed into NSX-T API. Currently only
# addressing IPv4 as that's where the issue has been identified.
# https://github.com/google/capirca/issues/348
zero_ip_address: list[nacaddr.IPType] = []
zero_ip_address: List[nacaddr.IPType] = []
if af == 4:
zero_ip_address: list[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]
zero_ip_address: List[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]

# source address
if self.term.source_address:
source_address: list[nacaddr.IPType] = self.term.GetAddressOfVersion(
source_address: List[nacaddr.IPType] = self.term.GetAddressOfVersion(
'source_address', af)
source_address_exclude: list[nacaddr.IPType] = (
source_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('source_address_exclude', af))

if source_address_exclude:
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address,
source_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address, zero_ip_address)
if source_address:
if af == 4:
source_address: list[nacaddr.IPv4]
source_v4_addr: list[nacaddr.IPv4] = source_address
source_address: List[nacaddr.IPv4]
source_v4_addr: List[nacaddr.IPv4] = source_address
if (source_v4_addr and
'0.0.0.0/0' in [str(a) for a in source_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -292,35 +293,35 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
source_v4_addr: list[nacaddr.IPv4] = []
source_v4_addr: List[nacaddr.IPv4] = []
else:
source_address: list[nacaddr.IPv6]
source_v6_addr: list[nacaddr.IPv6] = source_address
source_address: List[nacaddr.IPv6]
source_v6_addr: List[nacaddr.IPv6] = source_address
source_addr = source_v4_addr + source_v6_addr

# destination address
if self.term.destination_address:
destination_address: list[
destination_address: List[
nacaddr.IPType] = self.term.GetAddressOfVersion(
'destination_address', af)
destination_address_exclude: list[nacaddr.IPType] = (
destination_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('destination_address_exclude', af))

if destination_address_exclude:
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address,
destination_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address, zero_ip_address)
if destination_address:
if af == 4:
destination_address: list[nacaddr.IPv4]
dest_v4_addr: list[nacaddr.IPv4] = destination_address
destination_address: List[nacaddr.IPv4]
dest_v4_addr: List[nacaddr.IPv4] = destination_address
if (dest_v4_addr and
'0.0.0.0/0' in [str(a) for a in destination_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -330,10 +331,10 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
dest_v4_addr: list[nacaddr.IPv4] = []
dest_v4_addr: List[nacaddr.IPv4] = []
else:
destination_address: list[nacaddr.IPv6]
dest_v6_addr: list[nacaddr.IPv6] = destination_address
destination_address: List[nacaddr.IPv6]
dest_v6_addr: List[nacaddr.IPv6] = destination_address
destination_addr = dest_v4_addr + dest_v6_addr

# Check for mismatch IP for source and destination address for mixed filter
Expand Down Expand Up @@ -420,13 +421,13 @@ class Nsxt(aclgenerator.ACLGenerator):
_FILTER_OPTIONS_DICT = {}

def _TranslatePolicy(self, pol: policy.Policy, exp_info: int):
self.nsxt_policies: list[Tuple[policy.Header, str, list[Term]]] = []
self.nsxt_policies: List[Tuple[policy.Header, str, List[Term]]] = []
current_date = datetime.datetime.utcnow().date()

# Warn about policies that will expire in less than exp_info weeks.
exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

filters: list[Tuple[policy.Header, list[policy.Term]]] = pol.filters
filters: List[Tuple[policy.Header, List[policy.Term]]] = pol.filters
for header, terms in filters:
if self._PLATFORM not in header.platforms:
continue
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ absl-py
ply
PyYAML
six>=1.12.0
typing_extensions
17 changes: 9 additions & 8 deletions tests/lib/nsxt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import copy
import json
from typing import Any, Literal, Tuple, Union
from typing import Any, Tuple, Union, Dict, List
from unittest import mock

from absl.testing import absltest
Expand All @@ -25,6 +25,7 @@
from capirca.lib import naming
from capirca.lib import nsxt
from capirca.lib import policy
from typing_extensions import Literal


ICMPV6_TERM = """\
Expand Down Expand Up @@ -864,14 +865,14 @@ class TestTrafficKindGrid(parameterized.TestCase):

# Which address set should be put into the policy, based on the type of policy
# we're testing?
KIND_TO_ADDRESS: dict[_TRAFFIC_KIND, _ADDRESSES] = {
KIND_TO_ADDRESS: Dict[_TRAFFIC_KIND, _ADDRESSES] = {
'mixed': 'GOOGLE_DNS',
'v4': 'INTERNAL_V4',
'v6': 'INTERNAL_V6'}

# Which expanded address group (e.g. netblocks) is expected, based on the type
# of policy we're testing?
KIND_TO_ADDRESS_GROUPS: dict[
KIND_TO_ADDRESS_GROUPS: Dict[
_TRAFFIC_KIND, Union[nacaddr.IPv4, nacaddr.IPv6, Literal['ANY']]] = {
# 'GOOGLE_DNS'
'mixed': [nacaddr.IP('8.8.4.4/32'), nacaddr.IP('8.8.8.8/32'),
Expand Down Expand Up @@ -961,11 +962,11 @@ def test_generator_works(self):
' destination-address:: INTERNAL_V6',
'}']))

def get_source_dest_addresses(self, nsxt_json: dict[str, Any]) -> (
Tuple[list[str], list[str]]):
rules: list[dict[str, Any]] = nsxt_json['rules']
src: list[str] = []
dst: list[str] = []
def get_source_dest_addresses(self, nsxt_json: Dict[str, Any]) -> (
Tuple[List[str], List[str]]):
rules: List[Dict[str, Any]] = nsxt_json['rules']
src: List[str] = []
dst: List[str] = []

for rule in rules:
src.extend(i for i in rule['source_groups'])
Expand Down
Loading