@@ -24,11 +24,20 @@ def abcd(self):
2424 else :
2525 port = 9200
2626 host = "localhost"
27+ security_enabled = os .getenv ("security_enabled" ) == "true"
28+ if os .environ ["opensearch-version" ] == "latest" :
29+ credential = "admin:myStrongPassword123!"
30+ else :
31+ credential = "admin:admin"
2732
2833 logging .basicConfig (level = logging .INFO )
2934
30- url = f"opensearch://admin:admin@{ host } :{ port } "
31- opensearch_abcd = ABCD .from_url (url , index_name = "test_index" , use_ssl = False )
35+ url = f"opensearch://{ credential } @{ host } :{ port } "
36+ opensearch_abcd = ABCD .from_url (
37+ url ,
38+ index_name = "test_index" ,
39+ use_ssl = security_enabled ,
40+ )
3241 assert isinstance (opensearch_abcd , OpenSearchDatabase )
3342 return opensearch_abcd
3443
@@ -78,10 +87,9 @@ def test_push(self, abcd):
7887 assert isinstance (atoms_2 , Atoms )
7988 atoms_2 .set_cell ([1 , 1 , 1 ])
8089
90+ abcd .refresh ()
8191 result = AtomsModel (
82- None ,
83- None ,
84- abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
92+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
8593 ).to_ase ()
8694 assert atoms_1 == result
8795 assert atoms_2 != result
@@ -117,17 +125,14 @@ def test_bulk(self, abcd):
117125 atoms_list .append (atoms_1 )
118126 atoms_list .append (atoms_2 )
119127 abcd .push (atoms_list )
128+ abcd .refresh ()
120129 assert abcd .count () == 2
121130
122131 result_1 = AtomsModel (
123- None ,
124- None ,
125- abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
132+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][0 ]["_source" ],
126133 ).to_ase ()
127134 result_2 = AtomsModel (
128- None ,
129- None ,
130- abcd .client .search (index = "test_index" )["hits" ]["hits" ][1 ]["_source" ],
135+ dict = abcd .client .search (index = "test_index" )["hits" ]["hits" ][1 ]["_source" ],
131136 ).to_ase ()
132137 assert atoms_1 == result_1
133138 assert atoms_2 == result_2
@@ -151,4 +156,5 @@ def test_count(self, abcd):
151156 atoms .set_cell ([1 , 1 , 1 ])
152157 abcd .push (atoms )
153158 abcd .push (atoms )
159+ abcd .refresh ()
154160 assert abcd .count () == 2
0 commit comments