Commit 58060e2

fix: Normalize s3 paths for PME key retriever (#2874)
1 parent a7cf6cf commit 58060e2

2 files changed (+68, -2 lines)
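
For context, the failure this commit addresses is a cache-key mismatch in CometFileKeyUnwrapper: a DecryptionKeyRetriever stored under one S3 scheme (for example s3:// from ObjectStoreUrl) was not found when looked up under the s3a:// form that Spark uses. A minimal sketch of the pre-fix failure mode, using a hypothetical HashMap in place of the real retrieverCache:

import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for CometFileKeyUnwrapper's retrieverCache, showing the
// pre-fix failure mode: store and lookup use different S3 URI schemes.
public class SchemeMismatchSketch {
  public static void main(String[] args) {
    Map<String, Object> retrieverCache = new HashMap<>();

    // Retriever stored under the s3:// form of the path...
    retrieverCache.put("s3://bucket/data/file.parquet", new Object());

    // ...but looked up under the s3a:// form that Spark uses: a cache miss,
    // which surfaced as "Failed to find DecryptionKeyRetriever for path: ...".
    System.out.println(retrieverCache.get("s3a://bucket/data/file.parquet")); // prints null
  }
}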

common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java

Lines changed: 25 additions & 2 deletions
@@ -101,13 +101,35 @@ public class CometFileKeyUnwrapper {
   // Cache the hadoopConf just to assert the assumption above.
   private Configuration conf = null;
 
+  /**
+   * Normalizes S3 URI schemes to a canonical form. S3 can be accessed via multiple schemes (s3://,
+   * s3a://, s3n://) that refer to the same logical filesystem. This method ensures consistent cache
+   * lookups regardless of which scheme is used.
+   *
+   * @param filePath The file path that may contain an S3 URI
+   * @return The file path with normalized S3 scheme (s3a://)
+   */
+  private String normalizeS3Scheme(final String filePath) {
+    // Normalize s3:// and s3n:// to s3a:// for consistent cache lookups
+    // This handles the case where ObjectStoreUrl uses s3:// but Spark uses s3a://
+    String s3Prefix = "s3://";
+    String s3nPrefix = "s3n://";
+    if (filePath.startsWith(s3Prefix)) {
+      return "s3a://" + filePath.substring(s3Prefix.length());
+    } else if (filePath.startsWith(s3nPrefix)) {
+      return "s3a://" + filePath.substring(s3nPrefix.length());
+    }
+    return filePath;
+  }
+
   /**
    * Creates and stores a DecryptionKeyRetriever instance for the given file path.
    *
    * @param filePath The path to the Parquet file
    * @param hadoopConf The Hadoop Configuration to use for this file path
    */
   public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) {
+    final String normalizedPath = normalizeS3Scheme(filePath);
     // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call
     // getFileDecryptionProperties
     if (factory == null) {
@@ -122,7 +144,7 @@ public void storeDecryptionKeyRetriever(final String filePath, final Configurati
         factory.getFileDecryptionProperties(hadoopConf, path);
 
     DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever();
-    retrieverCache.put(filePath, keyRetriever);
+    retrieverCache.put(normalizedPath, keyRetriever);
   }
 
   /**
@@ -136,7 +158,8 @@ public void storeDecryptionKeyRetriever(final String filePath, final Configurati
    */
   public byte[] getKey(final String filePath, final byte[] keyMetadata)
       throws ParquetCryptoRuntimeException {
-    DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath);
+    final String normalizedPath = normalizeS3Scheme(filePath);
+    DecryptionKeyRetriever keyRetriever = retrieverCache.get(normalizedPath);
     if (keyRetriever == null) {
       throw new ParquetCryptoRuntimeException(
           "Failed to find DecryptionKeyRetriever for path: " + filePath);

spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala

Lines changed: 43 additions & 0 deletions
@@ -19,6 +19,12 @@
 
 package org.apache.comet.parquet
 
+import java.nio.charset.StandardCharsets
+import java.util.Base64
+
+import org.apache.parquet.crypto.DecryptionPropertiesFactory
+import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
+import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
 import org.apache.spark.sql.{DataFrame, SaveMode}
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -30,6 +36,15 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
 
   override protected val testBucketName = "test-bucket"
 
+  // Encryption keys for testing parquet encryption
+  private val encoder = Base64.getEncoder
+  private val footerKey =
+    encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
+  private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))
+  private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8))
+  private val cryptoFactoryClass =
+    "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory"
+
   private def writeTestParquetFile(filePath: String): Unit = {
     val df = spark.range(0, 1000)
     df.write.format("parquet").mode(SaveMode.Overwrite).save(filePath)
@@ -76,4 +91,32 @@ class ParquetReadFromS3Suite extends CometS3TestBase with AdaptiveSparkPlanHelpe
     assertCometScan(df)
     assert(df.first().getLong(0) == 499500)
   }
+
+  test("write and read encrypted parquet from S3") {
+    import testImplicits._
+
+    withSQLConf(
+      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
+      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
+        "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
+      InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
+        s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
+
+      val inputDF = spark
+        .range(0, 1000)
+        .map(i => (i, i.toString, i.toFloat))
+        .repartition(5)
+        .toDF("a", "b", "c")
+
+      val testFilePath = s"s3a://$testBucketName/data/encrypted-test.parquet"
+      inputDF.write
+        .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c")
+        .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
+        .parquet(testFilePath)
+
+      val df = spark.read.parquet(testFilePath).agg(sum(col("a")))
+      assertCometScan(df)
+      assert(df.first().getLong(0) == 499500)
+    }
+  }
 }
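
A note on the test's key setup: the key material is 16 bytes (AES-128), and the test Base64-encodes it into the comma-separated "keyName: base64Key" list that the mock InMemoryKMS consumes. An equivalent standalone preparation in Java, purely for illustration:

import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Mirrors the suite's key setup: 16-byte (AES-128) key material, Base64-encoded
// into the "keyName: base64Key" list format used by the mock InMemoryKMS.
public class KeyListDemo {
  public static void main(String[] args) {
    Base64.Encoder encoder = Base64.getEncoder();
    String footerKey = encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8));
    String key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8));
    String key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8));

    // Value that the test passes for InMemoryKMS.KEY_LIST_PROPERTY_NAME.
    System.out.println("footerKey: " + footerKey + ", key1: " + key1 + ", key2: " + key2);
  }
}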
