From a719e8d37616d09c64cd3031226f347faddcd1c1 Mon Sep 17 00:00:00 2001 From: Yuh Shin Ong Date: Fri, 15 Dec 2023 13:06:46 -0800 Subject: [PATCH] Parser: Support class intervals Summary: Support parsing class intervals. In the analysis, intervals are represented as: LocalTaint -> Kind -> Interval -> Frame In the json output, it is: LocalTaint -> Frame, where Frame contains the kind and interval. It's not as nice from a "reflect the underlying representation" perspective, but it reduces nesting and improves readability. I'm inclined to stick with the current representation. In SAPP, a trace_frame is uniquely identified by "caller, callee, caller port, callee port, interval". Therefore, the parser has to do some interval deduplication (see `Kind.partition_by_interval(...)`). A similar deduplication logic already exists for callees in origin frames. This diff adds "intervals" to the mix. Reviewed By: anwesht Differential Revision: D51931277 fbshipit-source-id: 75a8629682dc41e5884287e0eb29807e21aad438 --- sapp/pipeline/mariana_trench_parser.py | 192 ++++++---- .../tests/test_mariana_trench_parser.py | 342 ++++++++++++++++++ 2 files changed, 469 insertions(+), 65 deletions(-) diff --git a/sapp/pipeline/mariana_trench_parser.py b/sapp/pipeline/mariana_trench_parser.py index 65afe288..9450ee55 100644 --- a/sapp/pipeline/mariana_trench_parser.py +++ b/sapp/pipeline/mariana_trench_parser.py @@ -7,6 +7,8 @@ import logging import re import sys + +from collections import defaultdict from typing import ( Any, Dict, @@ -320,6 +322,8 @@ class Kind(NamedTuple): distance: int origins: List[Origin] extra_traces: List[ExtraTrace] + callee_interval: Optional[Tuple[int, int]] + preserves_type_context: bool @staticmethod def from_json( @@ -331,13 +335,29 @@ def from_json( extra_traces = [] for extra_trace in kind.get("extra_traces", []): extra_traces.append(ExtraTrace.from_json(extra_trace, caller_position)) + interval = kind.get("callee_interval") return Kind( name=kind["kind"], distance=kind.get("distance", 0), origins=origins, extra_traces=extra_traces, + callee_interval=(interval[0], interval[1]) if interval else None, + preserves_type_context=kind.get("preserves_type_context", False), ) + @staticmethod + def partition_by_interval( + kinds: List["Kind"], + ) -> Dict[Optional["ConditionTypeInterval"], List["Kind"]]: + kinds_by_interval = defaultdict(list) + for kind in kinds: + if kind.callee_interval is None: + kinds_by_interval[None].append(kind) + else: + interval = ConditionTypeInterval.from_kind(kind) + kinds_by_interval[interval].append(kind) + return kinds_by_interval + class ConditionLeaf(NamedTuple): kind: str @@ -375,6 +395,29 @@ def from_origin(origin: Origin, call_info: CallInfo) -> "ConditionCall": ) +class ConditionTypeInterval(NamedTuple): + start: int + finish: int + preserves_type_context: bool + + @staticmethod + def from_kind(kind: Kind) -> "ConditionTypeInterval": + if kind.callee_interval is None: + raise sapp.ParseError(f"Callee interval expected in {kind}") + return ConditionTypeInterval( + start=kind.callee_interval[0], + finish=kind.callee_interval[1], + preserves_type_context=kind.preserves_type_context, + ) + + def to_sapp(self) -> sapp.ParseTypeInterval: + return sapp.ParseTypeInterval( + start=self.start, + finish=self.finish, + preserves_type_context=self.preserves_type_context, + ) + + class Condition(NamedTuple): caller: ConditionCall callee: ConditionCall @@ -382,6 +425,7 @@ class Condition(NamedTuple): local_positions: LocalPositions features: Features extra_traces: Set[ExtraTrace] + type_interval: Optional[ConditionTypeInterval] def convert_to_sapp( self, kind: Literal[sapp.ParseType.PRECONDITION, sapp.ParseType.POSTCONDITION] @@ -394,7 +438,9 @@ def convert_to_sapp( callee=self.callee.method.name, callee_port=self.callee.port.value, callee_location=self.callee.position.to_sapp(), - type_interval=None, + type_interval=( + self.type_interval.to_sapp() if self.type_interval else None + ), features=self.features.to_sapp_as_parsetracefeature(), titos=self.local_positions.to_sapp(), leaves=[leaf.to_sapp() for leaf in self.leaves], @@ -426,6 +472,7 @@ class IssueCondition(NamedTuple): local_positions: LocalPositions features: Features extra_traces: Set[ExtraTrace] + type_interval: Optional[ConditionTypeInterval] def to_sapp(self) -> sapp.ParseIssueConditionTuple: return sapp.ParseIssueConditionTuple( @@ -435,7 +482,9 @@ def to_sapp(self) -> sapp.ParseIssueConditionTuple: leaves=[leaf.to_sapp() for leaf in self.leaves], titos=self.local_positions.to_sapp(), features=self.features.to_sapp_as_parsetracefeature(), - type_interval=None, + type_interval=( + self.type_interval.to_sapp() if self.type_interval else None + ), annotations=[extra_trace.to_sapp() for extra_trace in self.extra_traces], ) @@ -652,10 +701,12 @@ def _parse_issue_conditions( condition_taint["call_info"], leaf_kind, callable_position ) - kinds = [ - Kind.from_json(kind_json, leaf_kind, callable_position) - for kind_json in condition_taint["kinds"] - ] + kinds_by_interval = Kind.partition_by_interval( + [ + Kind.from_json(kind_json, leaf_kind, callable_position) + for kind_json in condition_taint["kinds"] + ] + ) issue_leaves.update( { @@ -664,6 +715,7 @@ def _parse_issue_conditions( kind=kind.name, distance=kind.distance, ) + for _, kinds in kinds_by_interval.items() for kind in kinds for origin in kind.origins } @@ -675,32 +727,37 @@ def _parse_issue_conditions( ) if call_info.is_origin(): - for kind in kinds: - condition_leaves = [ConditionLeaf.from_kind(kind)] - for origin in kind.origins: - conditions.append( - IssueCondition( - callee=ConditionCall.from_origin(origin, call_info), - leaves=condition_leaves, - local_positions=local_positions, - features=features, - extra_traces=set(kind.extra_traces), + for interval, kinds in kinds_by_interval.items(): + for kind in kinds: + condition_leaves = [ConditionLeaf.from_kind(kind)] + for origin in kind.origins: + conditions.append( + IssueCondition( + callee=ConditionCall.from_origin(origin, call_info), + leaves=condition_leaves, + local_positions=local_positions, + features=features, + extra_traces=set(kind.extra_traces), + type_interval=interval, + ) ) - ) else: - condition_leaves = [ConditionLeaf.from_kind(kind) for kind in kinds] - extra_traces = set() - for kind in kinds: - extra_traces.update(kind.extra_traces) - conditions.append( - IssueCondition( - callee=ConditionCall.from_call_info(call_info), - leaves=condition_leaves, - local_positions=local_positions, - features=features, - extra_traces=extra_traces, + for interval, kinds in kinds_by_interval.items(): + condition_leaves = [] + extra_traces = set() + for kind in kinds: + condition_leaves.append(ConditionLeaf.from_kind(kind)) + extra_traces.update(kind.extra_traces) + conditions.append( + IssueCondition( + callee=ConditionCall.from_call_info(call_info), + leaves=condition_leaves, + local_positions=local_positions, + features=features, + extra_traces=extra_traces, + type_interval=interval, + ) ) - ) return conditions, issue_leaves @@ -783,42 +840,47 @@ def _parse_condition( local_features = Features.from_taint_json(leaf_taint) kinds_json = leaf_taint["kinds"] - kinds = [ - Kind.from_json(kind_json, leaf_kind, caller_position) - for kind_json in kinds_json - ] + kinds_by_interval = Kind.partition_by_interval( + [ + Kind.from_json(kind_json, leaf_kind, caller_position) + for kind_json in kinds_json + ] + ) if call_info.is_origin(): - condition_by_callee = {} - for kind in kinds: - for origin in kind.origins: - callee = ConditionCall.from_origin(origin, call_info) - condition = condition_by_callee.get( - callee, - condition_class( - caller=caller, - callee=callee, - leaves=[], - local_positions=local_positions, - features=local_features, - extra_traces=set(), - ), - ) - condition.leaves.append(ConditionLeaf.from_kind(kind)) - condition.extra_traces.update(kind.extra_traces) - condition_by_callee[callee] = condition - for condition in condition_by_callee.values(): - yield condition + for interval, kinds in kinds_by_interval.items(): + condition_by_callee = {} + for kind in kinds: + for origin in kind.origins: + callee = ConditionCall.from_origin(origin, call_info) + condition = condition_by_callee.get( + callee, + condition_class( + caller=caller, + callee=callee, + leaves=[], + local_positions=local_positions, + features=local_features, + extra_traces=set(), + type_interval=interval, + ), + ) + condition.leaves.append(ConditionLeaf.from_kind(kind)) + condition.extra_traces.update(kind.extra_traces) + condition_by_callee[callee] = condition + for condition in condition_by_callee.values(): + yield condition else: - extra_traces = set() - for kind in kinds: - extra_traces.update(kind.extra_traces) - - yield condition_class( - caller=caller, - callee=ConditionCall.from_call_info(call_info), - leaves=[ConditionLeaf.from_kind(kind) for kind in kinds], - local_positions=local_positions, - features=local_features, - extra_traces=extra_traces, - ) + for interval, kinds in kinds_by_interval.items(): + extra_traces = set() + for kind in kinds: + extra_traces.update(kind.extra_traces) + yield condition_class( + caller=caller, + callee=ConditionCall.from_call_info(call_info), + leaves=[ConditionLeaf.from_kind(kind) for kind in kinds], + local_positions=local_positions, + features=local_features, + extra_traces=extra_traces, + type_interval=interval, + ) diff --git a/sapp/pipeline/tests/test_mariana_trench_parser.py b/sapp/pipeline/tests/test_mariana_trench_parser.py index 3d276213..e02eceb3 100644 --- a/sapp/pipeline/tests/test_mariana_trench_parser.py +++ b/sapp/pipeline/tests/test_mariana_trench_parser.py @@ -15,6 +15,7 @@ ParseTraceAnnotation, ParseTraceAnnotationSubtrace, ParseTraceFeature, + ParseTypeInterval, SourceLocation, ) from ..base_parser import ParseType @@ -2349,3 +2350,344 @@ def testModelPropagations(self) -> None: ) ], ) + + def testClassIntervals(self) -> None: + # Intervals at origin + self.assertParsed( + """ + { + "method": "LSink;.sink_wrapper:(LData;)V", + "sinks": [ + { + "port": "Argument(1)", + "taint": [ + { + "call_info": { + "call_kind": "Origin" + }, + "kinds": [ + { + "call_kind": "Origin", + "kind": "TestSink", + "origins": [ + { + "method": "LSink;.sink:(LData;)V", + "port": "Argument(1)" + } + ], + "callee_interval": [1, 2], + "preserves_type_context": true + }, + { + "call_kind": "Origin", + "kind": "TestSink", + "origins": [ + { + "method": "LSink;.sink:(LData;)V", + "port": "Argument(1)" + } + ], + "callee_interval": [3, 4], + "preserves_type_context": true + } + ] + } + ] + } + ], + "position": { + "line": 1, + "path": "Sink.java" + } + } + """, + [ + ParseConditionTuple( + type=ParseType.PRECONDITION, + caller="LSink;.sink_wrapper:(LData;)V", + callee="LSink;.sink:(LData;)V", + callee_location=SourceLocation( + line_no=1, + begin_column=1, + end_column=1, + ), + filename="Sink.java", + titos=[], + leaves=[("TestSink", 0)], + caller_port="argument(1)", + callee_port="sink", + type_interval=ParseTypeInterval( + start=1, finish=2, preserves_type_context=True + ), + features=[], + annotations=[], + ), + ParseConditionTuple( + type=ParseType.PRECONDITION, + caller="LSink;.sink_wrapper:(LData;)V", + callee="LSink;.sink:(LData;)V", + callee_location=SourceLocation( + line_no=1, + begin_column=1, + end_column=1, + ), + filename="Sink.java", + titos=[], + leaves=[("TestSink", 0)], + caller_port="argument(1)", + callee_port="sink", + type_interval=ParseTypeInterval( + start=3, finish=4, preserves_type_context=True + ), + features=[], + annotations=[], + ), + ], + ) + + # Intervals at call site + self.assertParsed( + """ + { + "method": "LClass;.indirect_sink:(LData;LData;)V", + "sinks": [ + { + "port": "Argument(2)", + "taint": [ + { + "call_info": { + "call_kind": "CallSite", + "resolves_to": "LSink;.sink:(LData;)V", + "port": "Argument(1)", + "position": { + "path": "Class.java", + "line": 10, + "start": 11, + "end": 12 + } + }, + "kinds": [ + { + "call_kind": "CallSite", + "distance": 1, + "kind": "TestSink", + "callee_interval": [10, 20], + "preserves_type_context": false + }, + { + "call_kind": "CallSite", + "distance": 2, + "kind": "TestSink2", + "callee_interval": [10, 20], + "preserves_type_context": false + }, + { + "call_kind": "CallSite", + "distance": 1, + "kind": "TestSink", + "callee_interval": [21, 30], + "preserves_type_context": true + } + ], + "local_positions": [ + {"line": 13, "start": 14, "end": 15}, + {"line": 16, "start": 17, "end": 18} + ] + } + ] + } + ], + "position": { + "line": 1, + "path": "Class.java" + } + } + """, + [ + ParseConditionTuple( + type=ParseType.PRECONDITION, + caller="LClass;.indirect_sink:(LData;LData;)V", + callee="LSink;.sink:(LData;)V", + callee_location=SourceLocation( + line_no=10, + begin_column=12, + end_column=13, + ), + filename="Class.java", + titos=[ + SourceLocation(line_no=13, begin_column=15, end_column=16), + SourceLocation(line_no=16, begin_column=18, end_column=19), + ], + leaves=[("TestSink", 1), ("TestSink2", 2)], + caller_port="argument(2)", + callee_port="argument(1)", + type_interval=ParseTypeInterval( + start=10, finish=20, preserves_type_context=False + ), + features=[], + annotations=[], + ), + ParseConditionTuple( + type=ParseType.PRECONDITION, + caller="LClass;.indirect_sink:(LData;LData;)V", + callee="LSink;.sink:(LData;)V", + callee_location=SourceLocation( + line_no=10, + begin_column=12, + end_column=13, + ), + filename="Class.java", + titos=[ + SourceLocation(line_no=13, begin_column=15, end_column=16), + SourceLocation(line_no=16, begin_column=18, end_column=19), + ], + leaves=[("TestSink", 1)], + caller_port="argument(2)", + callee_port="argument(1)", + type_interval=ParseTypeInterval( + start=21, finish=30, preserves_type_context=True + ), + features=[], + annotations=[], + ), + ], + ) + + # Intervals in issue condition + self.assertParsed( + """ + { + "method": "LClass;.flow:()V", + "issues": [ + { + "rule": 1, + "position": { + "path": "Flow.java", + "line": 10, + "start": 11, + "end": 12 + }, + "callee": "LSink;.sink:(LData;)V", + "sink_index": 0, + "sinks": [ + { + "call_info": { + "call_kind": "CallSite", + "resolves_to": "LSink;.sink:(LData;)V", + "port": "Argument(1)", + "position": { + "path": "Flow.java", + "line": 10, + "start": 11, + "end": 12 + } + }, + "kinds": [ + { + "call_kind": "CallSite", + "distance": 2, + "kind": "TestSink", + "origins": [ + { + "method": "LSink;.sink:(LData;)V", + "port": "Argument(1)" + } + ], + "callee_interval": [123, 456], + "preserves_type_context": true + } + ] + } + ], + "sources": [ + { + "call_info": { + "call_kind": "CallSite", + "resolves_to": "LSource;.source:()LData;", + "port": "Return", + "position": { + "path": "Flow.java", + "line": 20, + "start": 21, + "end": 22 + } + }, + "kinds": [ + { + "call_kind": "CallSite", + "distance": 3, + "kind": "TestSource", + "origins": [ + { + "method": "LSource;.source:(LData;)V", + "port": "Argument(1)" + } + ], + "callee_interval": [234, 345], + "preserves_type_context": true + } + ] + } + ] + } + ], + "position": { + "line": 2, + "path": "Flow.java" + } + } + """, + [ + ParseIssueTuple( + code=1, + message="TestRule: Test Rule Description", + callable="LClass;.flow:()V", + handle="LClass;.flow:()V:LSink;.sink:(LData;)V:0:1:1ef9022f932a64d0", + filename="Flow.java", + callable_line=2, + line=10, + start=12, + end=13, + preconditions=[ + ParseIssueConditionTuple( + callee="LSink;.sink:(LData;)V", + port="argument(1)", + location=SourceLocation( + line_no=10, + begin_column=12, + end_column=13, + ), + leaves=[("TestSink", 2)], + titos=[], + features=[], + type_interval=ParseTypeInterval( + start=123, finish=456, preserves_type_context=True + ), + annotations=[], + ) + ], + postconditions=[ + ParseIssueConditionTuple( + callee="LSource;.source:()LData;", + port="result", + location=SourceLocation( + line_no=20, + begin_column=22, + end_column=23, + ), + leaves=[("TestSource", 3)], + titos=[], + features=[], + type_interval=ParseTypeInterval( + start=234, finish=345, preserves_type_context=True + ), + annotations=[], + ) + ], + initial_sources={("LSource;.source:(LData;)V", "TestSource", 3)}, + final_sinks={("LSink;.sink:(LData;)V", "TestSink", 2)}, + features=[], + fix_info=None, + ) + ], + )