|
21 | 21 | class NormalizedQueryTest(tf.test.TestCase): |
22 | 22 |
|
23 | 23 | def test_normalization(self): |
24 | | - with self.cached_session() as sess: |
25 | | - record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. |
26 | | - record2 = tf.constant([4.0, -3.0]) # Not clipped. |
27 | | - |
28 | | - sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0) |
29 | | - query = normalized_query.NormalizedQuery( |
30 | | - numerator_query=sum_query, denominator=2.0) |
31 | | - |
32 | | - query_result, _ = test_utils.run_query(query, [record1, record2]) |
33 | | - result = sess.run(query_result) |
34 | | - expected = [0.5, 0.5] |
35 | | - self.assertAllClose(result, expected) |
| 24 | + record1 = tf.constant([-6.0, 8.0]) # Clipped to [-3.0, 4.0]. |
| 25 | + record2 = tf.constant([4.0, -3.0]) # Not clipped. |
| 26 | + |
| 27 | + sum_query = gaussian_query.GaussianSumQuery(l2_norm_clip=5.0, stddev=0.0) |
| 28 | + query = normalized_query.NormalizedQuery( |
| 29 | + numerator_query=sum_query, denominator=2.0) |
| 30 | + |
| 31 | + query_result, _ = test_utils.run_query(query, [record1, record2]) |
| 32 | + expected = [0.5, 0.5] |
| 33 | + self.assertAllClose(query_result, expected) |
36 | 34 |
|
37 | 35 |
|
38 | 36 | if __name__ == '__main__': |
|
0 commit comments