Simplify API #6

@AmitMY

Description

I got AToken installed, and I am now trying to encode and decode separately, similar to how it would work for Cosmos -
I want to encode and get only the latent codes, then decode and get the video back.

    @torch.no_grad()
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode a preprocessed video tensor to latent tokens."""
        # Remove batch dimension for AToken wrapper
        if video.dim() == 5:  # [B, T, H, W, C]
            video = video[0]  # [T, H, W, C]

        # Process the video
        video_sparse = self.wrapper.image_video_to_sparse_tensor([video])
        task_types = ["video"]
        kwargs = {"task_types": task_types}

        # Get the latent representation
        _, features, _ = self.wrapper.inference(video_sparse, **kwargs)
        return features

    @torch.no_grad()
    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latent tokens to video frames."""
        # For AToken, we need to reconstruct from the sparse representation
        # This is a simplified version - actual decoding may need more context
        task_types = ["video"]
        kwargs = {"task_types": task_types}

        rec, _, _ = self.wrapper.inference(latents, **kwargs)
        return rec
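For concreteness, this is roughly the round trip I am after, with a toy stand-in (ToyTokenizer and its downsample factors are made up for illustration, not AToken's real API - the factors are just what would map 36x128x128x3 to (9, 8, 8)):

```python
import torch

class ToyTokenizer:
    """Made-up stand-in for the two-step API I want: encode() returns a
    latent grid, decode() reconstructs the video from that grid alone."""

    def __init__(self, t_down=4, s_down=16):
        # Assumed temporal/spatial downsample factors, chosen so that a
        # 36-frame 128x128 video maps to a (9, 8, 8) latent grid.
        self.t_down, self.s_down = t_down, s_down

    def encode(self, video: torch.Tensor) -> torch.Tensor:
        # [T, H, W, C] -> latent grid [T/t_down, H/s_down, W/s_down]
        t, h, w, _ = video.shape
        return torch.zeros(t // self.t_down, h // self.s_down, w // self.s_down)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        # latent grid -> reconstructed video [T, H, W, C]
        t, h, w = latents.shape
        return torch.zeros(t * self.t_down, h * self.s_down, w * self.s_down, 3)

video = torch.rand(36, 128, 128, 3)
tok = ToyTokenizer()
latents = tok.encode(video)  # torch.Size([9, 8, 8])
rec = tok.decode(latents)    # torch.Size([36, 128, 128, 3])
```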
  1. In encode_video, the inference function runs both encode and decode - I only want it to run the encoder and return the latent codes.
  2. encode_video currently returns a torch.Size([1, 1152]) tensor - for Cosmos, this would instead be shaped like the video (e.g. for a video of 36 frames of 128x128x3, the latents would be (9, 8, 8)).
  3. decode_latents expects me to construct a SparseTensor, but I am not sure how:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/amit/dev/sign/video-tokenizer/video_tokenizer/bin.py", line 219, in <module>
    main()
  File "/home/amit/dev/sign/video-tokenizer/video_tokenizer/bin.py", line 207, in main
    video = tokenizer.decode_latents(latents)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amit/miniforge3/envs/tokenizer/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amit/dev/sign/video-tokenizer/video_tokenizer/atoken.py", line 88, in decode_latents
    rec, _, _ = self.wrapper.inference(latents, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/amit/miniforge3/envs/tokenizer/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/amit/miniforge3/envs/tokenizer/lib/python3.11/site-packages/atoken_inference/atoken_wrapper.py", line 155, in inference
    feats = x.feats.to(device=device, dtype=dtype)
            ^^^^^^^
AttributeError: 'Tensor' object has no attribute 'feats'
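For what it's worth, the failure is just that inference() reads .feats off its input, which a plain tensor does not have - a minimal reproduction (inference_like is a made-up stand-in mirroring that line of atoken_wrapper.py):

```python
import torch

def inference_like(x):
    # Mirrors the failing line in atoken_wrapper.py: it assumes the input
    # is a sparse container exposing .feats, not a raw torch.Tensor.
    return x.feats.to(dtype=torch.float32)

latents = torch.zeros(1, 1152)  # the shape encode_video currently returns
try:
    inference_like(latents)
except AttributeError as err:
    print(err)  # 'Tensor' object has no attribute 'feats'
```

So presumably the latents have to be wrapped back into whatever container image_video_to_sparse_tensor produces before they can be fed to the decoder.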
