diff --git a/datadog_lambda/asm.py b/datadog_lambda/asm.py index 9636760c3..11df24714 100644 --- a/datadog_lambda/asm.py +++ b/datadog_lambda/asm.py @@ -1,9 +1,12 @@ -from copy import deepcopy import logging +import urllib.parse +from copy import deepcopy from typing import Any, Dict, List, Optional, Union from ddtrace.contrib.internal.trace_utils import _get_request_header_client_ip from ddtrace.internal import core +from ddtrace.internal.utils import get_blocked +from ddtrace.internal.utils import http as http_utils from ddtrace.trace import Span from datadog_lambda.trigger import ( @@ -50,6 +53,7 @@ def asm_set_context(event_source: _EventSource): This allows the AppSecSpanProcessor to know information about the event at the moment the span is created and skip it when not relevant. """ + if event_source.event_type not in _http_event_types: core.set_item("appsec_skip_next_lambda_event", True) @@ -126,6 +130,14 @@ def asm_start_request( span.set_tag_str("http.client_ip", request_ip) span.set_tag_str("network.client.ip", request_ip) + # Encode the parsed query and append it to reconstruct the original raw URI expected by AppSec. + if parsed_query: + try: + encoded_query = urllib.parse.urlencode(parsed_query, doseq=True) + raw_uri += "?" + encoded_query # type: ignore + except Exception: + pass + core.dispatch( # The matching listener is registered in ddtrace.appsec._handlers "aws_lambda.start_request", @@ -182,3 +194,36 @@ def asm_start_response( response_headers, ), ) + + +def get_asm_blocked_response( + event_source: _EventSource, +) -> Optional[Dict[str, Any]]: + """Get the blocked response for the given event source.""" + if event_source.event_type not in _http_event_types: + return None + + blocked = get_blocked() + if not blocked: + return None + + desired_type = blocked.get("type", "auto") + if desired_type == "none": + content_type = "text/plain; charset=utf-8" + content = "" + else: + content_type = blocked.get("content-type", "application/json") + content = http_utils._get_blocked_template(content_type) + + response_headers = { + "content-type": content_type, + } + if "location" in blocked: + response_headers["location"] = blocked["location"] + + return { + "statusCode": blocked.get("status_code", 403), + "headers": response_headers, + "body": content, + "isBase64Encoded": False, + } diff --git a/datadog_lambda/wrapper.py b/datadog_lambda/wrapper.py index 06f8884c1..917a4fa0e 100644 --- a/datadog_lambda/wrapper.py +++ b/datadog_lambda/wrapper.py @@ -9,7 +9,7 @@ from importlib import import_module from time import time_ns -from datadog_lambda.asm import asm_set_context, asm_start_response, asm_start_request +from ddtrace.internal._exceptions import BlockingException from datadog_lambda.extension import should_use_extension, flush_extension from datadog_lambda.cold_start import ( set_cold_start, @@ -46,6 +46,14 @@ extract_http_status_code_tag, ) +if config.appsec_enabled: + from datadog_lambda.asm import ( + asm_set_context, + asm_start_response, + asm_start_request, + get_asm_blocked_response, + ) + if config.profiling_enabled: from ddtrace.profiling import profiler @@ -120,6 +128,7 @@ def __init__(self, func): self.span = None self.inferred_span = None self.response = None + self.blocking_response = None if config.profiling_enabled: self.prof = profiler.Profiler(env=config.env, service=config.service) @@ -159,8 +168,12 @@ def __call__(self, event, context, **kwargs): """Executes when the wrapped function gets called""" self._before(event, context) try: + if self.blocking_response: + return self.blocking_response self.response = self.func(event, context, **kwargs) return self.response + except BlockingException: + self.blocking_response = get_asm_blocked_response(self.event_source) except Exception: from datadog_lambda.metric import submit_errors_metric @@ -171,6 +184,8 @@ def __call__(self, event, context, **kwargs): raise finally: self._after(event, context) + if self.blocking_response: + return self.blocking_response def _inject_authorizer_span_headers(self, request_id): reference_span = self.inferred_span if self.inferred_span else self.span @@ -203,6 +218,7 @@ def _inject_authorizer_span_headers(self, request_id): def _before(self, event, context): try: self.response = None + self.blocking_response = None set_cold_start(init_timestamp_ns) if not should_use_extension: @@ -253,6 +269,7 @@ def _before(self, event, context): ) if config.appsec_enabled: asm_start_request(self.span, event, event_source, self.trigger_tags) + self.blocking_response = get_asm_blocked_response(self.event_source) else: set_correlation_ids() if config.profiling_enabled and is_new_sandbox(): @@ -286,13 +303,14 @@ def _after(self, event, context): if status_code: self.span.set_tag("http.status_code", status_code) - if config.appsec_enabled: + if config.appsec_enabled and not self.blocking_response: asm_start_response( self.span, status_code, self.event_source, response=self.response, ) + self.blocking_response = get_asm_blocked_response(self.event_source) self.span.finish() diff --git a/tests/test_asm.py b/tests/test_asm.py index e57c289fd..e3a5e027b 100644 --- a/tests/test_asm.py +++ b/tests/test_asm.py @@ -2,8 +2,17 @@ import pytest from unittest.mock import MagicMock, patch -from datadog_lambda.asm import asm_start_request, asm_start_response -from datadog_lambda.trigger import parse_event_source, extract_trigger_tags +from datadog_lambda.asm import ( + asm_start_request, + asm_start_response, + get_asm_blocked_response, +) +from datadog_lambda.trigger import ( + EventTypes, + _EventSource, + extract_trigger_tags, + parse_event_source, +) from tests.utils import get_mock_context event_samples = "tests/event_samples/" @@ -15,7 +24,7 @@ "application_load_balancer", "application-load-balancer.json", "72.12.164.125", - "/lambda", + "/lambda?query=1234ABCD", "GET", "", False, @@ -27,7 +36,7 @@ "application_load_balancer_multivalue_headers", "application-load-balancer-mutivalue-headers.json", "72.12.164.125", - "/lambda", + "/lambda?query=1234ABCD", "GET", "", False, @@ -51,7 +60,7 @@ "api_gateway", "api-gateway.json", "127.0.0.1", - "/path/to/resource", + "/path/to/resource?foo=bar", "POST", "eyJ0ZXN0IjoiYm9keSJ9", True, @@ -199,6 +208,40 @@ ), ] +ASM_BLOCKED_RESPONSE_TEST_CASES = [ + # JSON blocking response + ( + {"status_code": 403, "type": "auto", "content-type": "application/json"}, + 403, + {"content-type": "application/json"}, + ), + # HTML blocking response + ( + { + "status_code": 401, + "type": "html", + "content-type": "text/html", + }, + 401, + {"content-type": "text/html"}, + ), + # Plain text redirect response + ( + {"status_code": 301, "type": "none", "location": "https://example.com/blocked"}, + 301, + { + "content-type": "text/plain; charset=utf-8", + "location": "https://example.com/blocked", + }, + ), + # Default to content-type application/json and status code 403 when not provided + ( + {"type": "auto"}, + 403, + {"content-type": "application/json"}, + ), +] + @pytest.mark.parametrize( "name,file,expected_ip,expected_uri,expected_method,expected_body,expected_base64,expected_query,expected_path_params,expected_route", @@ -327,3 +370,31 @@ def test_asm_start_response_parametrized( else: # Verify core.dispatch was not called for non-HTTP events mock_core.dispatch.assert_not_called() + + +@pytest.mark.parametrize( + "blocked_config, expected_status, expected_headers", + ASM_BLOCKED_RESPONSE_TEST_CASES, +) +@patch("datadog_lambda.asm.get_blocked") +def test_get_asm_blocked_response_blocked( + mock_get_blocked, + blocked_config, + expected_status, + expected_headers, +): + mock_get_blocked.return_value = blocked_config + event_source = _EventSource(event_type=EventTypes.API_GATEWAY) + response = get_asm_blocked_response(event_source) + assert response["statusCode"] == expected_status + assert response["headers"] == expected_headers + + +@patch("datadog_lambda.asm.get_blocked") +def test_get_asm_blocked_response_not_blocked( + mock_get_blocked, +): + mock_get_blocked.return_value = None + event_source = _EventSource(event_type=EventTypes.API_GATEWAY) + response = get_asm_blocked_response(event_source) + assert response is None diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index fc081e904..e07b5ca91 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -2,8 +2,9 @@ import json import os import unittest +import importlib -from unittest.mock import patch, call, ANY +from unittest.mock import MagicMock, patch, call, ANY from datadog_lambda.constants import TraceHeader import datadog_lambda.wrapper as wrapper @@ -660,3 +661,117 @@ def lambda_handler(event, context): lambda_handler(lambda_event, lambda_context) self.assertEqual(len(flushes), 0) + + +class TestLambdaWrapperAppsecBlocking(unittest.TestCase): + def setUp(self): + os.environ["DD_APPSEC_ENABLED"] = "true" + os.environ["DD_TRACE_ENABLED"] = "true" + + importlib.reload(wrapper) + + self.addCleanup(os.environ.pop, "DD_APPSEC_ENABLED", None) + self.addCleanup(os.environ.pop, "DD_TRACE_ENABLED", None) + self.addCleanup(lambda: importlib.reload(wrapper)) + + patcher = patch("datadog_lambda.wrapper.asm_set_context") + self.mock_asm_set_context = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch("datadog_lambda.wrapper.asm_start_request") + self.mock_asm_start_request = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch("datadog_lambda.wrapper.asm_start_response") + self.mock_asm_start_response = patcher.start() + self.addCleanup(patcher.stop) + + patcher = patch("datadog_lambda.wrapper.get_asm_blocked_response") + self.mock_get_asm_blocking_response = patcher.start() + self.addCleanup(patcher.stop) + + self.fake_blocking_response = { + "statusCode": "403", + "headers": { + "Content-Type": "application/json", + }, + "body": '{"message": "Blocked by AppSec"}', + "isBase64Encoded": False, + } + + def test_blocking_before(self): + self.mock_get_asm_blocking_response.return_value = self.fake_blocking_response + + mock_handler = MagicMock() + + lambda_handler = wrapper.datadog_lambda_wrapper(mock_handler) + + response = lambda_handler({}, get_mock_context()) + self.assertEqual(response, self.fake_blocking_response) + + mock_handler.assert_not_called() + + self.mock_asm_set_context.assert_called_once() + self.mock_asm_start_request.assert_called_once() + self.mock_asm_start_response.assert_not_called() + + def test_blocking_during(self): + self.mock_get_asm_blocking_response.return_value = None + + @wrapper.datadog_lambda_wrapper + def lambda_handler(event, context): + self.mock_get_asm_blocking_response.return_value = ( + self.fake_blocking_response + ) + raise wrapper.BlockingException() + + response = lambda_handler({}, get_mock_context()) + self.assertEqual(response, self.fake_blocking_response) + + self.mock_asm_set_context.assert_called_once() + self.mock_asm_start_request.assert_called_once() + self.mock_asm_start_response.assert_not_called() + + def test_blocking_after(self): + self.mock_get_asm_blocking_response.return_value = None + + @wrapper.datadog_lambda_wrapper + def lambda_handler(event, context): + self.mock_get_asm_blocking_response.return_value = ( + self.fake_blocking_response + ) + return { + "statusCode": 200, + "body": "This should not be returned", + } + + response = lambda_handler({}, get_mock_context()) + self.assertEqual(response, self.fake_blocking_response) + + self.mock_asm_set_context.assert_called_once() + self.mock_asm_start_request.assert_called_once() + self.mock_asm_start_response.assert_called_once() + + def test_no_blocking_appsec_disabled(self): + os.environ["DD_APPSEC_ENABLED"] = "false" + + importlib.reload(wrapper) + + self.mock_get_asm_blocking_response.return_value = self.fake_blocking_response + + expected_response = { + "statusCode": 200, + "body": "This should be returned", + } + + @wrapper.datadog_lambda_wrapper + def lambda_handler(event, context): + return expected_response + + response = lambda_handler({}, get_mock_context()) + self.assertEqual(response, expected_response) + + self.mock_get_asm_blocking_response.assert_not_called() + self.mock_asm_set_context.assert_not_called() + self.mock_asm_start_request.assert_not_called() + self.mock_asm_start_response.assert_not_called()