Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Changed

- Relaxed video manifest creation to make use of keyframes even if seek lands earlier
(<https://github.com/cvat-ai/cvat/pull/9994>)
61 changes: 43 additions & 18 deletions utils/dataset_manifest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from .types import NamedBytesIO
from .utils import PcdReader, SortingMethod, md5_hash, rotate_image, sort

# how many frames to check after seeking to validate key frame
SEEK_MISMATCH_UPPER_BOUND = 200


class VideoStreamReader:
def __init__(self, source_path, chunk_size, force):
Expand Down Expand Up @@ -70,12 +73,35 @@ def __len__(self):
def resolution(self):
return (self.width, self.height)

def validate_key_frame(self, container, video_stream, key_frame):
for packet in container.demux(video_stream):
for frame in packet.decode():
if md5_hash(frame) != key_frame["md5"] or frame.pts != key_frame["pts"]:
return False
return True
def validate_key_frame(
self,
container: av.container.InputContainer,
video_stream: av.video.stream.VideoStream,
key_frame: dict,
prev_seek_pts: Optional[int],
) -> Optional[int]:
"""
Returns a pts of the first decoded frame after seeking to the key_frame pts
Returns None if the key frame is not suitable for seeking
"""
container.seek(offset=key_frame["pts"], stream=video_stream)

frames = (frame for packet in container.demux(video_stream) for frame in packet.decode())
frames = islice(frames, SEEK_MISMATCH_UPPER_BOUND)

seek_pts = None
for frame in frames:
if seek_pts is None:
seek_pts = frame.pts
# if seek landed on the same frame as previous seek, it is redundant
if prev_seek_pts == seek_pts:
return None
if frame.pts < key_frame["pts"]:
continue
if md5_hash(frame) != key_frame["md5"] or frame.pts != key_frame["pts"]:
return None
return seek_pts
return None

def __iter__(self) -> Iterator[Union[int, tuple[int, int, str]]]:
"""
Expand All @@ -94,6 +120,7 @@ def __iter__(self) -> Iterator[Union[int, tuple[int, int, str]]]:
prev_pts: Optional[int] = None
prev_dts: Optional[int] = None
index, key_frame_count = 0, 0
prev_seek_pts: Optional[int] = None

for packet in reading_container.demux(reading_v_stream):
for frame in packet.decode():
Expand All @@ -111,17 +138,15 @@ def __iter__(self) -> Iterator[Union[int, tuple[int, int, str]]]:
}

# Check that it is possible to seek to this key frame using frame.pts
checking_container.seek(
offset=key_frame_data["pts"],
stream=checking_v_stream,
)
is_valid_key_frame = self.validate_key_frame(
seek_pts = self.validate_key_frame(
checking_container,
checking_v_stream,
key_frame_data,
prev_seek_pts,
)

if is_valid_key_frame:
if seek_pts is not None:
prev_seek_pts = seek_pts
key_frame_count += 1
yield (index, key_frame_data["pts"], key_frame_data["md5"])
else:
Expand Down Expand Up @@ -614,12 +639,12 @@ def _get_video_stream(container):
return video_stream

def validate_key_frame(self, container, video_stream, key_frame):
for packet in container.demux(video_stream):
for frame in packet.decode():
assert (
frame.pts == key_frame["pts"]
), "The uploaded manifest does not match the video"
return
frames = (frame for packet in container.demux(video_stream) for frame in packet.decode())
frames = islice(frames, SEEK_MISMATCH_UPPER_BOUND)

assert any(
frame.pts == key_frame["pts"] for frame in frames
), "The uploaded manifest does not match the video"

def validate_seek_key_frames(self):
with closing(av.open(self._source_path, mode="r")) as container:
Expand Down
Loading