diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/main/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptor.java b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/main/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptor.java index fb9ca4b1495..6a063562f7f 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/main/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptor.java +++ b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/main/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptor.java @@ -1,10 +1,7 @@ package datadog.trace.instrumentation.aws.v2.sns; import static datadog.context.propagation.Propagators.defaultPropagator; -import static datadog.trace.api.datastreams.DataStreamsTags.Direction.OUTBOUND; -import static datadog.trace.api.datastreams.DataStreamsTags.create; import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.traceConfig; -import static datadog.trace.instrumentation.aws.v2.sns.TextMapInjectAdapter.SETTER; import datadog.context.Context; import datadog.trace.api.Config; @@ -28,6 +25,9 @@ public class SnsInterceptor implements ExecutionInterceptor { + // SQS subscriber limit; SNS inherits it when SQS is used as a subscriber + private static final int MAX_MESSAGE_ATTRIBUTES = 10; + public static final ExecutionAttribute CONTEXT_ATTRIBUTE = InstanceStore.of(ExecutionAttribute.class) .putIfAbsent("DatadogContext", () -> new ExecutionAttribute<>("DatadogContext")); @@ -38,10 +38,12 @@ private SdkBytes getMessageAttributeValueToInject( StringBuilder jsonBuilder = new StringBuilder(); jsonBuilder.append('{'); if (traceConfig().isDataStreamsEnabled()) { - DataStreamsContext dsmContext = DataStreamsContext.fromTags(getTags(snsTopicName)); + DataStreamsTags tags = + DataStreamsTags.create("sns", DataStreamsTags.Direction.OUTBOUND, snsTopicName); + DataStreamsContext dsmContext = DataStreamsContext.fromTags(tags); context = context.with(dsmContext); } - defaultPropagator().inject(context, jsonBuilder, SETTER); + defaultPropagator().inject(context, jsonBuilder, TextMapInjectAdapter.SETTER); jsonBuilder.setLength(jsonBuilder.length() - 1); // Remove the last comma jsonBuilder.append('}'); return SdkBytes.fromString(jsonBuilder.toString(), StandardCharsets.UTF_8); @@ -57,9 +59,7 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu // Injecting the trace context into SNS messageAttributes. if (context.request() instanceof PublishRequest) { PublishRequest request = (PublishRequest) context.request(); - // 10 messageAttributes is a limit from SQS, which is often used as a subscriber, therefore - // the limit still applies here - if (request.messageAttributes().size() < 10) { + if (request.messageAttributes().size() < MAX_MESSAGE_ATTRIBUTES) { // Get topic name for DSM String snsTopicArn = request.topicArn(); if (null == snsTopicArn) { @@ -70,17 +70,11 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu } String snsTopicName = snsTopicArn.substring(snsTopicArn.lastIndexOf(':') + 1); - Map modifiedMessageAttributes = - new HashMap<>(request.messageAttributes()); - modifiedMessageAttributes.put( - "_datadog", // Use Binary since SNS subscription filter policies fail silently with JSON - // strings https://github.com/DataDog/datadog-lambda-js/pull/269 - MessageAttributeValue.builder() - .dataType("Binary") - .binaryValue( - this.getMessageAttributeValueToInject(executionAttributes, snsTopicName)) - .build()); - return request.toBuilder().messageAttributes(modifiedMessageAttributes).build(); + Map messageAttributes = + withDatadogAttribute( + request.messageAttributes(), + this.getMessageAttributeValueToInject(executionAttributes, snsTopicName)); + return request.toBuilder().messageAttributes(messageAttributes).build(); } return request; } else if (context.request() instanceof PublishBatchRequest) { @@ -89,24 +83,29 @@ public SdkRequest modifyRequest(ModifyRequest context, ExecutionAttributes execu String snsTopicArn = request.topicArn(); String snsTopicName = snsTopicArn.substring(snsTopicArn.lastIndexOf(':') + 1); ArrayList entries = new ArrayList<>(); - final SdkBytes sdkBytes = - this.getMessageAttributeValueToInject(executionAttributes, snsTopicName); for (PublishBatchRequestEntry entry : request.publishBatchRequestEntries()) { - if (entry.messageAttributes().size() < 10) { - Map modifiedMessageAttributes = - new HashMap<>(entry.messageAttributes()); - modifiedMessageAttributes.put( - "_datadog", - MessageAttributeValue.builder().dataType("Binary").binaryValue(sdkBytes).build()); - entries.add(entry.toBuilder().messageAttributes(modifiedMessageAttributes).build()); + if (entry.messageAttributes().size() < MAX_MESSAGE_ATTRIBUTES) { + Map messageAttributes = + withDatadogAttribute( + entry.messageAttributes(), + this.getMessageAttributeValueToInject(executionAttributes, snsTopicName)); + entry = entry.toBuilder().messageAttributes(messageAttributes).build(); } + entries.add(entry); } return request.toBuilder().publishBatchRequestEntries(entries).build(); } return context.request(); } - private DataStreamsTags getTags(String snsTopicName) { - return create("sns", OUTBOUND, snsTopicName); + private static Map withDatadogAttribute( + Map attributes, SdkBytes value) { + // copy since the original map may be unmodifiable + Map modified = new HashMap<>(attributes); + // Use Binary since SNS subscription filter policies fail silently with JSON strings + // https://github.com/DataDog/datadog-lambda-js/pull/269 + modified.put( + "_datadog", MessageAttributeValue.builder().dataType("Binary").binaryValue(value).build()); + return modified; } } diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/groovy/SnsClientTest.groovy b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/groovy/SnsClientTest.groovy index 7df92d4724a..d340515d7a1 100644 --- a/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/groovy/SnsClientTest.groovy +++ b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/groovy/SnsClientTest.groovy @@ -13,7 +13,6 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.sns.SnsClient -import software.amazon.awssdk.services.sns.model.MessageAttributeValue import software.amazon.awssdk.services.sns.model.PublishResponse import software.amazon.awssdk.services.sqs.SqsClient import software.amazon.awssdk.services.sqs.model.QueueAttributeName @@ -41,15 +40,15 @@ abstract class SnsClientTest extends VersionedNamingTestBase { LOCALSTACK.start() def endPoint = "http://" + LOCALSTACK.getHost() + ":" + LOCALSTACK.getMappedPort(4566) snsClient = SnsClient.builder() - .endpointOverride(URI.create(endPoint)) - .region(Region.of("us-east-1")) - .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) - .build() + .endpointOverride(URI.create(endPoint)) + .region(Region.of("us-east-1")) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) + .build() sqsClient = SqsClient.builder() - .endpointOverride(URI.create(endPoint)) - .region(Region.of("us-east-1")) - .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) - .build() + .endpointOverride(URI.create(endPoint)) + .region(Region.of("us-east-1")) + .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("test", "test"))) + .build() testQueueURL = sqsClient.createQueue { it.queueName("testqueue") }.queueUrl() testQueueARN = sqsClient.getQueueAttributes {it.queueUrl(testQueueURL).attributeNames(QueueAttributeName.QUEUE_ARN)}.attributes().get(QueueAttributeName.QUEUE_ARN) testTopicARN = snsClient.createTopic { it.name("testtopic") }.topicArn() @@ -82,30 +81,6 @@ abstract class SnsClientTest extends VersionedNamingTestBase { abstract String expectedOperation(String awsService, String awsOperation) abstract String expectedService(String awsService, String awsOperation) - def "trace details propagated when message attributes are readonly"() { - when: - TEST_WRITER.clear() - - def headers = new HashMap() - headers.put("mykey", MessageAttributeValue.builder().stringValue("myvalue").dataType("String").build()) - def readonlyHeaders = Collections.unmodifiableMap(headers) - snsClient.publish(b -> b.message("sometext").topicArn(testTopicARN).messageAttributes(readonlyHeaders)) - - def message = sqsClient.receiveMessage { - it.queueUrl(testQueueURL).waitTimeSeconds(3) - }.messages().get(0) - - def messageBody = new JsonSlurper().parseText(message.body()) - - then: - // injected value is here - String injectedValue = messageBody["MessageAttributes"]["_datadog"]["Value"] - injectedValue.length() > 0 - - // original header value is still present - messageBody["MessageAttributes"]["mykey"] != null - } - def "trace details propagated via SNS system message attributes"() { when: TEST_WRITER.clear() @@ -214,7 +189,7 @@ abstract class SnsClientTest extends VersionedNamingTestBase { TEST_WRITER.clear() snsClient.publish { req -> req.message("test message") - .topicArn(testTopicARN) + .topicArn(testTopicARN) } def message = sqsClient.receiveMessage { it.queueUrl(testQueueURL).waitTimeSeconds(3) }.messages().get(0) @@ -339,4 +314,3 @@ class SnsClientV1DataStreamsForkedTest extends SnsClientTest { 1 } } - diff --git a/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptorTest.java b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptorTest.java new file mode 100644 index 00000000000..74e4ad2d454 --- /dev/null +++ b/dd-java-agent/instrumentation/aws-java/aws-java-sns-2.0/src/test/java/datadog/trace/instrumentation/aws/v2/sns/SnsInterceptorTest.java @@ -0,0 +1,95 @@ +package datadog.trace.instrumentation.aws.v2.sns; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import datadog.context.Context; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.services.sns.model.MessageAttributeValue; +import software.amazon.awssdk.services.sns.model.PublishBatchRequest; +import software.amazon.awssdk.services.sns.model.PublishBatchRequestEntry; +import software.amazon.awssdk.services.sns.model.PublishRequest; + +public class SnsInterceptorTest { + + @Test + void publishBatchPreservesEntriesAndOnlyInjectsBelowTheMessageAttributeLimit() { + PublishBatchRequest batchRequest = + PublishBatchRequest.builder() + .topicArn("arn:aws:sns:us-east-1:123456789012:test-topic") + .publishBatchRequestEntries( + PublishBatchRequestEntry.builder() + .id("at-limit") + .message("first") + .messageAttributes(stringAttributes(10)) + .build(), + PublishBatchRequestEntry.builder() + .id("under-limit") + .message("second") + .messageAttributes(stringAttributes(9)) + .build()) + .build(); + + PublishBatchRequest modified = + (PublishBatchRequest) + new SnsInterceptor().modifyRequest(() -> batchRequest, executionAttributes()); + + assertEquals( + Arrays.asList("at-limit", "under-limit"), + modified.publishBatchRequestEntries().stream() + .map(PublishBatchRequestEntry::id) + .collect(Collectors.toList())); + assertFalse( + modified.publishBatchRequestEntries().get(0).messageAttributes().containsKey("_datadog")); + assertTrue( + modified.publishBatchRequestEntries().get(1).messageAttributes().containsKey("_datadog")); + } + + @Test + void publishPreservesReadonlyAttributesWhileAddingDatadogContext() { + Map headers = new HashMap<>(); + headers.put( + "mykey", MessageAttributeValue.builder().dataType("String").stringValue("myvalue").build()); + Map readonlyHeaders = Collections.unmodifiableMap(headers); + + PublishRequest request = + PublishRequest.builder() + .topicArn("arn:aws:sns:us-east-1:123456789012:test-topic") + .message("sometext") + .messageAttributes(readonlyHeaders) + .build(); + + PublishRequest modified = + (PublishRequest) new SnsInterceptor().modifyRequest(() -> request, executionAttributes()); + + assertNotSame(readonlyHeaders, modified.messageAttributes()); + assertEquals("myvalue", modified.messageAttributes().get("mykey").stringValue()); + assertTrue(modified.messageAttributes().containsKey("_datadog")); + assertFalse(readonlyHeaders.containsKey("_datadog")); + } + + private static ExecutionAttributes executionAttributes() { + ExecutionAttributes executionAttributes = new ExecutionAttributes(); + executionAttributes.putAttribute(SnsInterceptor.CONTEXT_ATTRIBUTE, Context.root()); + return executionAttributes; + } + + private static Map stringAttributes(int count) { + Map attributes = new LinkedHashMap<>(); + for (int index = 1; index <= count; index++) { + attributes.put( + "key" + index, + MessageAttributeValue.builder().dataType("String").stringValue("value" + index).build()); + } + return attributes; + } +}