Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 49 additions & 26 deletions tests/test_repository.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: `from py4j.protocol import Py4JError` is newly added while the `pytest` import was removed. `Py4JError` is actually used in the tests below, so the new import is needed; and removing `pytest` is correct, since `pytest.mark.xfail` is no longer used anywhere in this file.

import unittest

import pytest
from py4j.protocol import Py4JError
from pyspark.sql import Row

from pydeequ.analyzers import *
from pydeequ.checks import *
from pydeequ.repository import *
from pydeequ.verification import *
from pydeequ.analyzers import AnalyzerContext, AnalysisRunner, ApproxCountDistinct
from pydeequ.checks import Check, CheckLevel
from pydeequ.repository import FileSystemMetricsRepository, InMemoryMetricsRepository, ResultKey
from pydeequ.verification import VerificationResult, VerificationSuite
from tests.conftest import setup_pyspark


Expand All @@ -18,7 +18,9 @@ def setUpClass(cls):
cls.AnalysisRunner = AnalysisRunner(cls.spark)
cls.VerificationSuite = VerificationSuite(cls.spark)
cls.sc = cls.spark.sparkContext
cls.df = cls.sc.parallelize([Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]).toDF()
cls.df = cls.sc.parallelize(
[Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]
).toDF()

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -121,12 +123,16 @@ def test_verifications_FSmetrep(self):
)

# TEST: Check JSON for tags
result_metrep_json = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
result_metrep_json = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
)

print(result_metrep_json[0]["tag"], key_tags["tag"])
self.assertEqual(result_metrep_json[0]["tag"], key_tags["tag"])

result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
Expand All @@ -146,7 +152,9 @@ def test_verifications_FSmetrep_noTags_noFile(self):
)

# TEST: Check DF parity
result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
Expand Down Expand Up @@ -243,12 +251,16 @@ def test_verifications_IMmetrep(self):
)

# TEST: Check JSON for tags
result_metrep_json = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
result_metrep_json = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
)

print(result_metrep_json[0]["tag"], key_tags["tag"])
self.assertEqual(result_metrep_json[0]["tag"], key_tags["tag"])

result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
Expand All @@ -267,37 +279,43 @@ def test_verifications_IMmetrep_noTags_noFile(self):
)

# TEST: Check DF parity
result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
print(result_metrep.collect())

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_no_useRepository(self):
"""This test should fail because it doesn't call useRepository() before saveOrAppendResult()"""
"""This run fails because it doesn't call useRepository() before saveOrAppendResult()."""
metrics_file = FileSystemMetricsRepository.helper_metrics_file(self.spark, "metrics.json")
print(f"metrics filepath: {metrics_file}")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_fail_no_useRepository asserts a specific Py4J error-message substring, "Method saveOrAppendResult([class com.amazon.deequ.repository.ResultKey]) does not exist". This is fragile: the exact wording depends on the Py4J and Deequ JAR versions, so upgrading Deequ could change the message and break the test. Consider asserting only `self.assertRaises(Py4JError)` without inspecting the message, or matching a less specific substring such as "saveOrAppendResult".

key_tags = {"tag": "FS metrep analyzers -- FAIL"}
resultKey = ResultKey(self.spark, ResultKey.current_milli_time(), key_tags)

# MISSING useRepository()
result = (
self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.saveOrAppendResult(resultKey)
.run()
with self.assertRaises(Py4JError) as err:
_ = (
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test_fail_no_load test asserts the error-message string exactly. Here is why the current assertion passes: `repository.before(...)` calls `self._check_RepositoryLoader()`, which accesses `self.RepositoryLoader`; because that attribute is never initialized in `__init__` of `FileSystemMetricsRepository` (or its parent `MetricsRepository`), the access raises `AttributeError: 'FileSystemMetricsRepository' object has no attribute 'RepositoryLoader'`. The assertion is correct for the current code, but it pins an implementation detail that could change. Consider using `self.assertIn`, or simply `self.assertRaises(AttributeError)` without checking the exact message.

self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.saveOrAppendResult(resultKey)
.run()
)

self.assertIn(
"Method saveOrAppendResult([class com.amazon.deequ.repository.ResultKey]) does not exist",
str(err.exception),
)

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_no_load(self):
"""This test should fail because we do not load() for the repository reading"""
"""This run fails because we do not load() for the repository reading."""
metrics_file = FileSystemMetricsRepository.helper_metrics_file(self.spark, "metrics.json")
print(f"metrics filepath: {metrics_file}")
repository = FileSystemMetricsRepository(self.spark, metrics_file)
key_tags = {"tag": "FS metrep analyzers"}
resultKey = ResultKey(self.spark, ResultKey.current_milli_time(), key_tags)
result = (
_ = (
self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.useRepository(repository)
Expand All @@ -306,8 +324,13 @@ def test_fail_no_load(self):
)

# MISSING: repository.load()
result_metrep_json = (
repository.before(ResultKey.current_milli_time())
.forAnalyzers([ApproxCountDistinct("b")])
.getSuccessMetricsAsJson()
with self.assertRaises(AttributeError) as err:
_ = (
repository.before(ResultKey.current_milli_time())
.forAnalyzers([ApproxCountDistinct("b")])
.getSuccessMetricsAsJson()
)

self.assertEqual(
"'FileSystemMetricsRepository' object has no attribute 'RepositoryLoader'", str(err.exception)
)