Skip to content

Commit

Permalink
VPCEndpoint: add security group modification
Browse files Browse the repository at this point in the history
  • Loading branch information
SoenkeD committed May 3, 2024
1 parent 0e8d096 commit 3733040
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
21 changes: 20 additions & 1 deletion moto/ec2/models/vpcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def modify(
add_subnets: Optional[List[str]],
add_route_tables: Optional[List[str]],
remove_route_tables: Optional[List[str]],
add_security_groups: Optional[List[str]],
remove_security_groups: Optional[List[str]],
) -> None:
if policy_doc:
self.policy_document = policy_doc
Expand All @@ -389,6 +391,14 @@ def modify(
for rt_id in self.route_table_ids
if rt_id not in remove_route_tables
]
if add_security_groups:
self.security_group_ids.extend(add_security_groups)
if remove_security_groups:
self.security_group_ids = [
sg_id
for sg_id in self.security_group_ids
if sg_id not in remove_security_groups
]

def get_filter_value(
self, filter_name: str, method_name: Optional[str] = None
Expand Down Expand Up @@ -966,9 +976,18 @@ def modify_vpc_endpoint(
add_subnets: Optional[List[str]],
remove_route_tables: Optional[List[str]],
add_route_tables: Optional[List[str]],
add_security_groups: Optional[List[str]],
remove_security_groups: Optional[List[str]],
) -> None:
endpoint = self.describe_vpc_endpoints(vpc_end_point_ids=[vpc_id])[0]
endpoint.modify(policy_doc, add_subnets, add_route_tables, remove_route_tables)
endpoint.modify(
policy_doc,
add_subnets,
add_route_tables,
remove_route_tables,
add_security_groups,
remove_security_groups,
)

def delete_vpc_endpoints(self, vpce_ids: Optional[List[str]] = None) -> None:
for vpce_id in vpce_ids or []:
Expand Down
27 changes: 22 additions & 5 deletions moto/ec2/responses/vpcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,17 @@ def modify_vpc_endpoint(self) -> str:
add_route_tables = self._get_multi_param("AddRouteTableId")
remove_route_tables = self._get_multi_param("RemoveRouteTableId")
policy_doc = self._get_param("PolicyDocument")
add_security_groups = self._get_multi_param("AddSecurityGroupId")
remove_security_groups = self._get_multi_param("RemoveSecurityGroupId")

self.ec2_backend.modify_vpc_endpoint(
vpc_id=vpc_id,
policy_doc=policy_doc,
add_subnets=add_subnets,
add_route_tables=add_route_tables,
remove_route_tables=remove_route_tables,
add_security_groups=add_security_groups,
remove_security_groups=remove_security_groups,
)
template = self.response_template(MODIFY_VPC_END_POINT)
return template.render()
Expand All @@ -251,9 +256,17 @@ def describe_vpc_endpoints(self) -> str:
vpc_end_points = self.ec2_backend.describe_vpc_endpoints(
vpc_end_point_ids=vpc_end_points_ids, filters=filters
)

security_group_ids = []
for service in vpc_end_points:
security_group_ids.extend(service.security_group_ids)
security_groups = self.ec2_backend.describe_security_groups(group_ids=security_group_ids)

template = self.response_template(DESCRIBE_VPC_ENDPOINT_RESPONSE)
return template.render(
vpc_end_points=vpc_end_points, account_id=self.current_account
vpc_end_points=vpc_end_points,
account_id=self.current_account,
security_groups=security_groups,
)

def delete_vpc_endpoints(self) -> str:
Expand Down Expand Up @@ -736,10 +749,14 @@ def modify_managed_prefix_list(self) -> str:
{% if vpc_end_point.security_group_ids %}
<groupSet>
{% for group_id in vpc_end_point.security_group_ids %}
<item>
<groupId>{{ group_id }}</groupId>
<groupName>TODO</groupName>
</item>
{% for sec_g in security_groups %}
{% if sec_g.id == group_id %}
<item>
<groupId>{{ group_id }}</groupId>
<groupName>{{ sec_g.name }}</groupName>
</item>
{% endif %}
{% endfor %}
{% endfor %}
</groupSet>
{% endif %}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_ec2/test_vpcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,24 @@ def test_modify_vpc_endpoint():
endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0]
assert endpoint["PolicyDocument"] == "doc"

sg_id_1 = ec2.create_security_group(GroupName="sg-1", VpcId=vpc_id, Description="sg-1")["GroupId"]
sg_id_2 = ec2.create_security_group(GroupName="sg-2", VpcId=vpc_id, Description="sg-2")["GroupId"]
ec2.modify_vpc_endpoint(
VpcEndpointId=vpc_id,
AddSecurityGroupIds=[sg_id_1, sg_id_2],
)
endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0]
group_ids = endpoint.get("Groups", [])
assert any(group.get("GroupId") == sg_id_1 and group.get("GroupName") == "sg-1" for group in group_ids)
assert any(group.get("GroupId") == sg_id_2 and group.get("GroupName") == "sg-2" for group in group_ids)

ec2.modify_vpc_endpoint(
VpcEndpointId=vpc_id,
RemoveSecurityGroupIds=[sg_id_1],
)
endpoint = ec2.describe_vpc_endpoints(VpcEndpointIds=[vpc_id])["VpcEndpoints"][0]
assert any(group.get("GroupId") == sg_id_2 and group.get("GroupName") == "sg-2" for group in group_ids)


@mock_aws
def test_delete_vpc_end_points():
Expand Down

0 comments on commit 3733040

Please sign in to comment.