Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package datadog.trace.instrumentation.aws.v2.sns;

import static datadog.context.propagation.Propagators.defaultPropagator;
import static datadog.trace.api.datastreams.DataStreamsTags.Direction.OUTBOUND;
import static datadog.trace.api.datastreams.DataStreamsTags.create;
import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.traceConfig;
import static datadog.trace.instrumentation.aws.v2.sns.TextMapInjectAdapter.SETTER;

import datadog.context.Context;
import datadog.trace.api.Config;
Expand All @@ -28,6 +25,9 @@

public class SnsInterceptor implements ExecutionInterceptor {

// SQS subscriber limit; SNS inherits it when SQS is used as a subscriber
private static final int MAX_MESSAGE_ATTRIBUTES = 10;
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introduce a constant for clarity and deduplication


public static final ExecutionAttribute<Context> CONTEXT_ATTRIBUTE =
InstanceStore.of(ExecutionAttribute.class)
.putIfAbsent("DatadogContext", () -> new ExecutionAttribute<>("DatadogContext"));
Expand All @@ -38,10 +38,12 @@ private SdkBytes getMessageAttributeValueToInject(
StringBuilder jsonBuilder = new StringBuilder();
jsonBuilder.append('{');
if (traceConfig().isDataStreamsEnabled()) {
DataStreamsContext dsmContext = DataStreamsContext.fromTags(getTags(snsTopicName));
DataStreamsTags tags =
DataStreamsTags.create("sns", DataStreamsTags.Direction.OUTBOUND, snsTopicName);
DataStreamsContext dsmContext = DataStreamsContext.fromTags(tags);
context = context.with(dsmContext);
}
defaultPropagator().inject(context, jsonBuilder, SETTER);
defaultPropagator().inject(context, jsonBuilder, TextMapInjectAdapter.SETTER);
jsonBuilder.setLength(jsonBuilder.length() - 1); // Remove the last comma
jsonBuilder.append('}');
return SdkBytes.fromString(jsonBuilder.toString(), StandardCharsets.UTF_8);
Expand All @@ -57,9 +59,7 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu
// Injecting the trace context into SNS messageAttributes.
if (context.request() instanceof PublishRequest) {
PublishRequest request = (PublishRequest) context.request();
// 10 messageAttributes is a limit from SQS, which is often used as a subscriber, therefore
// the limit still applies here
if (request.messageAttributes().size() < 10) {
if (request.messageAttributes().size() < MAX_MESSAGE_ATTRIBUTES) {
// Get topic name for DSM
String snsTopicArn = request.topicArn();
if (null == snsTopicArn) {
Expand All @@ -70,17 +70,11 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu
}

String snsTopicName = snsTopicArn.substring(snsTopicArn.lastIndexOf(':') + 1);
Map<String, MessageAttributeValue> modifiedMessageAttributes =
new HashMap<>(request.messageAttributes());
modifiedMessageAttributes.put(
"_datadog", // Use Binary since SNS subscription filter policies fail silently with JSON
// strings https://github.com/DataDog/datadog-lambda-js/pull/269
MessageAttributeValue.builder()
.dataType("Binary")
.binaryValue(
this.getMessageAttributeValueToInject(executionAttributes, snsTopicName))
.build());
return request.toBuilder().messageAttributes(modifiedMessageAttributes).build();
Map<String, MessageAttributeValue> messageAttributes =
withDatadogAttribute(
request.messageAttributes(),
this.getMessageAttributeValueToInject(executionAttributes, snsTopicName));
return request.toBuilder().messageAttributes(messageAttributes).build();
}
return request;
} else if (context.request() instanceof PublishBatchRequest) {
Expand All @@ -89,24 +83,29 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu
String snsTopicArn = request.topicArn();
String snsTopicName = snsTopicArn.substring(snsTopicArn.lastIndexOf(':') + 1);
ArrayList<PublishBatchRequestEntry> entries = new ArrayList<>();
final SdkBytes sdkBytes =
this.getMessageAttributeValueToInject(executionAttributes, snsTopicName);
for (PublishBatchRequestEntry entry : request.publishBatchRequestEntries()) {
if (entry.messageAttributes().size() < 10) {
Map<String, MessageAttributeValue> modifiedMessageAttributes =
new HashMap<>(entry.messageAttributes());
modifiedMessageAttributes.put(
"_datadog",
MessageAttributeValue.builder().dataType("Binary").binaryValue(sdkBytes).build());
entries.add(entry.toBuilder().messageAttributes(modifiedMessageAttributes).build());
if (entry.messageAttributes().size() < MAX_MESSAGE_ATTRIBUTES) {
Map<String, MessageAttributeValue> messageAttributes =
withDatadogAttribute(
entry.messageAttributes(),
this.getMessageAttributeValueToInject(executionAttributes, snsTopicName));
entry = entry.toBuilder().messageAttributes(messageAttributes).build();
}
entries.add(entry);
}
return request.toBuilder().publishBatchRequestEntries(entries).build();
}
return context.request();
}

private DataStreamsTags getTags(String snsTopicName) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlined getTags as it's used only once

return create("sns", OUTBOUND, snsTopicName);
private static Map<String, MessageAttributeValue> withDatadogAttribute(
Map<String, MessageAttributeValue> attributes, SdkBytes value) {
// copy since the original map may be unmodifiable
Map<String, MessageAttributeValue> modified = new HashMap<>(attributes);
// Use Binary since SNS subscription filter policies fail silently with JSON strings
// https://github.com/DataDog/datadog-lambda-js/pull/269
modified.put(
Copy link
Copy Markdown
Contributor Author

@ygree ygree Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extract duplicated parts to this helper method

"_datadog", MessageAttributeValue.builder().dataType("Binary").binaryValue(value).build());
return modified;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.sns.SnsClient
import software.amazon.awssdk.services.sns.model.MessageAttributeValue
import software.amazon.awssdk.services.sns.model.PublishResponse
import software.amazon.awssdk.services.sqs.SqsClient
import software.amazon.awssdk.services.sqs.model.QueueAttributeName
Expand Down Expand Up @@ -41,15 +40,15 @@ abstract class SnsClientTest extends VersionedNamingTestBase {
LOCALSTACK.start()
def endPoint = "http://" + LOCALSTACK.getHost() + ":" + LOCALSTACK.getMappedPort(4566)
snsClient = SnsClient.builder()
.endpointOverride(URI.create(endPoint))
.region(Region.of("us-east-1"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.build()
.endpointOverride(URI.create(endPoint))
Copy link
Copy Markdown
Contributor Author

@ygree ygree Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These parts and below were auto-formatted on commit

.region(Region.of("us-east-1"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.build()
sqsClient = SqsClient.builder()
.endpointOverride(URI.create(endPoint))
.region(Region.of("us-east-1"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.build()
.endpointOverride(URI.create(endPoint))
.region(Region.of("us-east-1"))
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test")))
.build()
testQueueURL = sqsClient.createQueue { it.queueName("testqueue") }.queueUrl()
testQueueARN = sqsClient.getQueueAttributes {it.queueUrl(testQueueURL).attributeNames(QueueAttributeName.QUEUE_ARN)}.attributes().get(QueueAttributeName.QUEUE_ARN)
testTopicARN = snsClient.createTopic { it.name("testtopic") }.topicArn()
Expand Down Expand Up @@ -82,30 +81,6 @@ abstract class SnsClientTest extends VersionedNamingTestBase {
abstract String expectedOperation(String awsService, String awsOperation)
abstract String expectedService(String awsService, String awsOperation)

def "trace details propagated when message attributes are readonly"() {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved to the JUnit test suit

when:
TEST_WRITER.clear()

def headers = new HashMap<String, MessageAttributeValue>()
headers.put("mykey", MessageAttributeValue.builder().stringValue("myvalue").dataType("String").build())
def readonlyHeaders = Collections.unmodifiableMap(headers)
snsClient.publish(b -> b.message("sometext").topicArn(testTopicARN).messageAttributes(readonlyHeaders))

def message = sqsClient.receiveMessage {
it.queueUrl(testQueueURL).waitTimeSeconds(3)
}.messages().get(0)

def messageBody = new JsonSlurper().parseText(message.body())

then:
// injected value is here
String injectedValue = messageBody["MessageAttributes"]["_datadog"]["Value"]
injectedValue.length() > 0

// original header value is still present
messageBody["MessageAttributes"]["mykey"] != null
}

def "trace details propagated via SNS system message attributes"() {
when:
TEST_WRITER.clear()
Expand Down Expand Up @@ -214,7 +189,7 @@ abstract class SnsClientTest extends VersionedNamingTestBase {
TEST_WRITER.clear()
snsClient.publish { req ->
req.message("test message")
.topicArn(testTopicARN)
.topicArn(testTopicARN)
}

def message = sqsClient.receiveMessage { it.queueUrl(testQueueURL).waitTimeSeconds(3) }.messages().get(0)
Expand Down Expand Up @@ -339,4 +314,3 @@ class SnsClientV1DataStreamsForkedTest extends SnsClientTest {
1
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package datadog.trace.instrumentation.aws.v2.sns;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertTrue;

import datadog.context.Context;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.services.sns.model.MessageAttributeValue;
import software.amazon.awssdk.services.sns.model.PublishBatchRequest;
import software.amazon.awssdk.services.sns.model.PublishBatchRequestEntry;
import software.amazon.awssdk.services.sns.model.PublishRequest;

public class SnsInterceptorTest {

@Test
void publishBatchPreservesEntriesAndOnlyInjectsBelowTheMessageAttributeLimit() {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New test that reproduces the issue

PublishBatchRequest batchRequest =
PublishBatchRequest.builder()
.topicArn("arn:aws:sns:us-east-1:123456789012:test-topic")
.publishBatchRequestEntries(
PublishBatchRequestEntry.builder()
.id("at-limit")
.message("first")
.messageAttributes(stringAttributes(10))
.build(),
PublishBatchRequestEntry.builder()
.id("under-limit")
.message("second")
.messageAttributes(stringAttributes(9))
.build())
.build();

PublishBatchRequest modified =
(PublishBatchRequest)
new SnsInterceptor().modifyRequest(() -> batchRequest, executionAttributes());

assertEquals(
Arrays.asList("at-limit", "under-limit"),
modified.publishBatchRequestEntries().stream()
.map(PublishBatchRequestEntry::id)
.collect(Collectors.toList()));
assertFalse(
modified.publishBatchRequestEntries().get(0).messageAttributes().containsKey("_datadog"));
assertTrue(
modified.publishBatchRequestEntries().get(1).messageAttributes().containsKey("_datadog"));
}

@Test
void publishPreservesReadonlyAttributesWhileAddingDatadogContext() {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Migrated from Spock to JUnit

Map<String, MessageAttributeValue> headers = new HashMap<>();
headers.put(
"mykey", MessageAttributeValue.builder().dataType("String").stringValue("myvalue").build());
Map<String, MessageAttributeValue> readonlyHeaders = Collections.unmodifiableMap(headers);

PublishRequest request =
PublishRequest.builder()
.topicArn("arn:aws:sns:us-east-1:123456789012:test-topic")
.message("sometext")
.messageAttributes(readonlyHeaders)
.build();

PublishRequest modified =
(PublishRequest) new SnsInterceptor().modifyRequest(() -> request, executionAttributes());

assertNotSame(readonlyHeaders, modified.messageAttributes());
assertEquals("myvalue", modified.messageAttributes().get("mykey").stringValue());
assertTrue(modified.messageAttributes().containsKey("_datadog"));
assertFalse(readonlyHeaders.containsKey("_datadog"));
}

private static ExecutionAttributes executionAttributes() {
ExecutionAttributes executionAttributes = new ExecutionAttributes();
executionAttributes.putAttribute(SnsInterceptor.CONTEXT_ATTRIBUTE, Context.root());
return executionAttributes;
}

private static Map<String, MessageAttributeValue> stringAttributes(int count) {
Map<String, MessageAttributeValue> attributes = new LinkedHashMap<>();
for (int index = 1; index <= count; index++) {
attributes.put(
"key" + index,
MessageAttributeValue.builder().dataType("String").stringValue("value" + index).build());
}
return attributes;
}
}
Loading