diff --git a/aws-xray-recorder-sdk-core/build.gradle.kts b/aws-xray-recorder-sdk-core/build.gradle.kts index e1c7d7c5..e65ac527 100644 --- a/aws-xray-recorder-sdk-core/build.gradle.kts +++ b/aws-xray-recorder-sdk-core/build.gradle.kts @@ -6,6 +6,7 @@ plugins { dependencies { api("commons-logging:commons-logging:1.3.5") + implementation("software.amazon.awssdk:utils-lite:2.34.0") implementation("com.fasterxml.jackson.core:jackson-annotations:2.17.0") implementation("com.fasterxml.jackson.core:jackson-databind:2.17.0") implementation("com.google.auto.value:auto-value-annotations:1.10.4") diff --git a/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/contexts/LambdaSegmentContext.java b/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/contexts/LambdaSegmentContext.java index 02c67e72..0f1173dd 100644 --- a/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/contexts/LambdaSegmentContext.java +++ b/aws-xray-recorder-sdk-core/src/main/java/com/amazonaws/xray/contexts/LambdaSegmentContext.java @@ -30,20 +30,28 @@ import java.util.Objects; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import software.amazon.awssdk.utilslite.SdkInternalThreadLocal; public class LambdaSegmentContext implements SegmentContext { private static final Log logger = LogFactory.getLog(LambdaSegmentContext.class); private static final String LAMBDA_TRACE_HEADER_KEY = "_X_AMZN_TRACE_ID"; - + private static final String CONCURRENT_TRACE_ID_KEY = "AWS_LAMBDA_X_TRACE_ID"; + // See: https://github.com/aws/aws-xray-sdk-java/issues/251 private static final String LAMBDA_TRACE_HEADER_PROP = "com.amazonaws.xray.traceHeader"; public static TraceHeader getTraceHeaderFromEnvironment() { - String lambdaTraceHeaderKey = System.getenv(LAMBDA_TRACE_HEADER_KEY); - return TraceHeader.fromString(lambdaTraceHeaderKey != null && lambdaTraceHeaderKey.length() > 0 - ? lambdaTraceHeaderKey - : System.getProperty(LAMBDA_TRACE_HEADER_PROP)); + String lambdaTraceHeaderKeyFromMdc = SdkInternalThreadLocal.get(CONCURRENT_TRACE_ID_KEY); + String lambdaTraceHeaderKeyFromEnvVar = System.getenv(LAMBDA_TRACE_HEADER_KEY); + + if (lambdaTraceHeaderKeyFromMdc != null && lambdaTraceHeaderKeyFromMdc.length() > 0) { + return TraceHeader.fromString(lambdaTraceHeaderKeyFromMdc); + } else if (lambdaTraceHeaderKeyFromEnvVar != null && lambdaTraceHeaderKeyFromEnvVar.length() > 0) { + return TraceHeader.fromString(lambdaTraceHeaderKeyFromEnvVar); + } else { + return TraceHeader.fromString(System.getProperty(LAMBDA_TRACE_HEADER_PROP)); + } } // SuppressWarnings is needed for passing Root TraceId to noOp segment diff --git a/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/contexts/LambdaSegmentContextTest.java b/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/contexts/LambdaSegmentContextTest.java index 2f7a728b..5213a8a6 100644 --- a/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/contexts/LambdaSegmentContextTest.java +++ b/aws-xray-recorder-sdk-core/src/test/java/com/amazonaws/xray/contexts/LambdaSegmentContextTest.java @@ -40,6 +40,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import software.amazon.awssdk.utilslite.SdkInternalThreadLocal; @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) @@ -174,4 +175,118 @@ private static void testContextResultsInNoOpSegmentParent() { mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); assertThat(AWSXRay.getTraceEntity()).isNull(); } + + @Test + @SetSystemProperty(key = "com.amazonaws.xray.traceHeader", value = TRACE_HEADER) + void testSystemPropertyFallbackWithTraceValidation() { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify system property values are used correctly + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } + + @Test + @SetEnvironmentVariable(key = "_X_AMZN_TRACE_ID", value = TRACE_HEADER) + void testEnvironmentVariableFallbackWithTraceValidation() { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify system property values are used correctly + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } + + @Test + @SetSystemProperty(key = "com.amazonaws.xray.traceHeader", value = TRACE_HEADER_2) + void testSdkInternalThreadLocalTakesPriorityOverSystemProperty() { + SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", TRACE_HEADER); + + try { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify SdkInternalThreadLocal values are used (TRACE_HEADER), not system property values (TRACE_HEADER_2) + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } finally { + SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID"); + } + } + + @Test + @SetSystemProperty(key = "com.amazonaws.xray.traceHeader", value = TRACE_HEADER) + void testSdkInternalThreadLocalWithEmptyStringFallsBackToSystemProperty() { + SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", ""); + + try { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify system property values are used as fallback when SdkInternalThreadLocal returns empty string + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } finally { + SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID"); + } + } + + @Test + @SetEnvironmentVariable(key = "_X_AMZN_TRACE_ID", value = TRACE_HEADER_2) + void testSdkInternalThreadLocalTakesPriorityOverEnvironmentVariable() { + SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", TRACE_HEADER); + + try { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify SdkInternalThreadLocal values are used (TRACE_HEADER), not environment variable values (TRACE_HEADER_2) + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } finally { + SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID"); + } + } + + @Test + @SetEnvironmentVariable(key = "_X_AMZN_TRACE_ID", value = TRACE_HEADER) + void testSdkInternalThreadLocalWithEmptyStringFallsBackToEnvironmentVariable() { + SdkInternalThreadLocal.put("AWS_LAMBDA_X_TRACE_ID", ""); + + try { + LambdaSegmentContext mockContext = new LambdaSegmentContext(); + Subsegment subsegment = mockContext.beginSubsegment(AWSXRay.getGlobalRecorder(), "test"); + FacadeSegment parent = (FacadeSegment) subsegment.getParent(); + + // Verify environment variable values are used as fallback when SdkInternalThreadLocal returns empty string + assertThat(parent.getTraceId().toString()).isEqualTo("1-57ff426a-80c11c39b0c928905eb0828d"); + assertThat(parent.getId()).isEqualTo("1234abcd1234abcd"); + assertThat(parent.isSampled()).isTrue(); + + mockContext.endSubsegment(AWSXRay.getGlobalRecorder()); + } finally { + SdkInternalThreadLocal.remove("AWS_LAMBDA_X_TRACE_ID"); + } + } }