Skip to content

Commit 1caf770

Browse files
committed
add pytests
Signed-off-by: Sarah Yurick <[email protected]>
1 parent d48e6b5 commit 1caf770

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

tests/stages/common/test_base.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,106 @@ def process(self, task: MockTask) -> MockTask:
265265
assert stage_with_custom2.resources == Resources(cpus=7.0)
266266

267267

268+
class TestProcessingStageOverriddenProperties:
269+
"""Test that ProcessingStage raises an error if a derived class overrides the name, resources, or batch_size property."""
270+
271+
def test_name_property(self):
272+
"""Test that ProcessingStage raises an error if a derived class overrides the name property."""
273+
with pytest.raises(TypeError, match="MockStageOverriddenName must not override 'name'"):
274+
275+
class MockStageOverriddenName(ProcessingStage[MockTask, MockTask]):
276+
"""Mock stage with overridden name property."""
277+
278+
_name = "MockStageOverriddenName"
279+
_resources = Resources(cpus=1.0)
280+
_batch_size = 1
281+
282+
# A derived class must not override the name property
283+
def name(self) -> str:
284+
return self._name
285+
286+
def process(self, task: MockTask) -> MockTask:
287+
return task
288+
289+
def inputs(self) -> tuple[list[str], list[str]]:
290+
return [], []
291+
292+
def outputs(self) -> tuple[list[str], list[str]]:
293+
return [], []
294+
295+
def test_resources_property(self):
296+
"""Test that ProcessingStage raises an error if a derived class overrides the resources property."""
297+
with pytest.raises(TypeError, match="MockStageOverriddenResources must not override 'resources'"):
298+
299+
class MockStageOverriddenResources(ProcessingStage[MockTask, MockTask]):
300+
"""Mock stage with overridden resources property."""
301+
302+
_name = "MockStageOverriddenResources"
303+
_resources = Resources(cpus=1.0)
304+
_batch_size = 1
305+
306+
# A derived class must not override the resources property
307+
def resources(self) -> Resources:
308+
return self._resources
309+
310+
def process(self, task: MockTask) -> MockTask:
311+
return task
312+
313+
def inputs(self) -> tuple[list[str], list[str]]:
314+
return [], []
315+
316+
def outputs(self) -> tuple[list[str], list[str]]:
317+
return [], []
318+
319+
def test_batch_size_property(self):
320+
"""Test that ProcessingStage raises an error if a derived class overrides the batch_size property."""
321+
with pytest.raises(TypeError, match="MockStageOverriddenBatchSize must not override 'batch_size'"):
322+
323+
class MockStageOverriddenBatchSize(ProcessingStage[MockTask, MockTask]):
324+
"""Mock stage with overridden batch_size property."""
325+
326+
_name = "MockStageOverriddenBatchSize"
327+
_resources = Resources(cpus=1.0)
328+
_batch_size = 1
329+
330+
# A derived class must not override the batch_size property
331+
def batch_size(self) -> int:
332+
return self._batch_size
333+
334+
def process(self, task: MockTask) -> MockTask:
335+
return task
336+
337+
def inputs(self) -> tuple[list[str], list[str]]:
338+
return [], []
339+
340+
def outputs(self) -> tuple[list[str], list[str]]:
341+
return [], []
342+
343+
def test_nested_class_inheritance(self):
344+
"""Test that nested class inheritance raises an error if a derived class overrides the name, resources, or batch_size property."""
345+
with pytest.raises(TypeError, match="MockStageNestedOverriddenName must not override 'name'"):
346+
347+
class MockStageNestedOverriddenName(ConcreteProcessingStage):
348+
"""Mock stage with nested class inheritance."""
349+
350+
_name = "MockStageNestedOverriddenName"
351+
_resources = Resources(cpus=1.0)
352+
_batch_size = 1
353+
354+
# A derived class must not override the name property
355+
def name(self) -> str:
356+
return self._name
357+
358+
def process(self, task: MockTask) -> MockTask:
359+
return task
360+
361+
def inputs(self) -> tuple[list[str], list[str]]:
362+
return [], []
363+
364+
def outputs(self) -> tuple[list[str], list[str]]:
365+
return [], []
366+
367+
268368
# Mock stages for testing composite stage functionality
269369
class MockStageA(ProcessingStage[MockTask, MockTask]):
270370
"""Mock stage A for testing composite stages."""

0 commit comments

Comments
 (0)