Skip to content

Commit 9e25eee

Browse files
schien1729tensorflower-gardener
authored andcommitted
Update remaining DPQuery tests to TF2.
PiperOrigin-RevId: 468793518
1 parent fd64be5 commit 9e25eee

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

tensorflow_privacy/privacy/dp_query/normalized_query_test.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,16 @@
2121
class NormalizedQueryTest(tf.test.TestCase):
2222

2323
def test_normalization(self):
24-
with self.cached_session() as sess:
25-
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
26-
record2 = tf.constant([4.0, -3.0]) # Not clipped.
27-
28-
sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
29-
query = normalized_query.NormalizedQuery(
30-
numerator_query=sum_query, denominator=2.0)
31-
32-
query_result, _ = test_utils.run_query(query, [record1, record2])
33-
result = sess.run(query_result)
34-
expected = [0.5, 0.5]
35-
self.assertAllClose(result, expected)
24+
record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0].
25+
record2 = tf.constant([4.0, -3.0]) # Not clipped.
26+
27+
sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0)
28+
query = normalized_query.NormalizedQuery(
29+
numerator_query=sum_query, denominator=2.0)
30+
31+
query_result, _ = test_utils.run_query(query, [record1, record2])
32+
expected = [0.5, 0.5]
33+
self.assertAllClose(query_result, expected)
3634

3735

3836
if __name__ == '__main__':

0 commit comments

Comments
 (0)