import contextlib
import os
import shutil
import subprocess
import tempfile
from typing import Any
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPModel
from feluda import Operator
from feluda.factory import VideoFactory
[docs]
class ClassifyVideoZeroShot(Operator):
    """Operator to classify a video into given labels using CLIP-ViT-B-32 and a zero-shot approach.

    The operator loads the `openai/clip-vit-base-patch32` checkpoint once at
    construction time and requires the `ffmpeg` executable on the system PATH.
    For each run, frames extracted from the video are scored against the
    caller-supplied text labels, and the per-frame probabilities are averaged
    into a single distribution over the labels.
    """
[docs]
def __init__(self) -> None:
    """Set up the operator: pick a device, check for FFmpeg, and load CLIP.

    Raises:
        RuntimeError: If FFmpeg cannot be found on the system PATH, or if
            the CLIP model or its processor fails to load.
    """
    super().__init__()

    # Prefer the GPU whenever torch reports one is available.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device_name)

    # Check the system dependency before attempting the model load.
    self.validate_system()

    checkpoint = "openai/clip-vit-base-patch32"
    try:
        self.processor = AutoProcessor.from_pretrained(checkpoint)
        self.model = CLIPModel.from_pretrained(checkpoint)
        self.model.to(self.device)
    except Exception as load_error:
        raise RuntimeError(
            f"Failed to load CLIP model or processor: {load_error}"
        ) from load_error

    # Per-run state; starts empty and is filled in by later calls.
    self.labels: list[str] = []
    self.frame_images: list[Image.Image] = []
    self.probs: torch.Tensor | None = None
[docs]
@staticmethod
def validate_system() -> None:
    """Check that FFmpeg, the operator's system dependency, is available.

    Raises:
        RuntimeError: If no `ffmpeg` executable is found on the system PATH.
    """
    ffmpeg_path = shutil.which("ffmpeg")
    if ffmpeg_path is not None:
        return
    raise RuntimeError(
        "FFmpeg is not installed or not found in system PATH. "
        "Please install FFmpeg to use this operator."
    )
[docs]
def gen_data(self) -> dict[str, Any]:
    """Generate output dict with prediction and probabilities.

    Returns:
        dict: A dictionary containing:
            - `prediction` (str | None): Label with the highest probability,
              or `None` when no probabilities have been computed yet.
            - `probs` (list): Label probabilities, or `[]` when no
              probabilities have been computed yet.
    """
    # Without this guard the argmax lookup would raise AttributeError on
    # None, even though the `probs` entry was already written to cope
    # with that case.
    if self.probs is None:
        return {"prediction": None, "probs": []}

    best_index = int(self.probs.argmax().item())
    return {
        "prediction": self.labels[best_index],
        "probs": self.probs.tolist(),
    }
[docs]
def analyze(self) -> None:
    """Extract frames from the video and score them against the labels.

    Takes no arguments: frames come from `self.extract_frames()` and the
    labels from `self.labels`. The frames are stored on
    `self.frame_images` and the result of `self.predict` on `self.probs`.

    Raises:
        RuntimeError: If frame extraction yields no frames.
    """
    frames = self.extract_frames()
    self.frame_images = frames
    if not frames:
        raise RuntimeError(
            "No frames extracted from video. Check if the video is valid."
        )
    self.probs = self.predict(frames, self.labels)
[docs]
def predict(self, images: list[Image.Image], labels: list[str]) -> torch.Tensor:
    """Score images against text labels with the loaded CLIP model.

    Args:
        images (list): List of PIL Images
        labels (list): List of candidate labels

    Returns:
        torch.Tensor: One probability per label, obtained by applying a
        softmax along dim 1 of `logits_per_image` and then averaging
        along dim 0.

    Raises:
        ValueError: If `images` or `labels` is empty.
    """
    if not images:
        raise ValueError("Image list for prediction must not be empty.")
    if not labels:
        raise ValueError("Label list for prediction must not be empty.")

    encoded = self.processor(
        text=labels,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Move every processor output onto the same device as the model.
    batch = {}
    for key, value in encoded.items():
        batch[key] = value.to(self.device)

    with torch.no_grad():
        logits = self.model(**batch).logits_per_image
        per_image = torch.softmax(logits, dim=1)
        return per_image.mean(dim=0)
[docs]
def run(
    self,
    file: VideoFactory,
    labels: list[str],
    remove_after_processing: bool = False,
) -> dict[str, Any]:
    """Run the operator.

    Args:
        file (dict): VideoFactory file object (must have a 'path' key)
        labels (list): List of labels
        remove_after_processing (bool): Whether to remove the file after
            processing. The file is removed whether or not processing
            succeeds.

    Returns:
        dict: A dictionary containing prediction and probabilities

    Raises:
        TypeError: If `file` is not a dict with a 'path' key, or `labels`
            is not a list of strings.
        ValueError: If `labels` is empty.
        FileNotFoundError: If the path is not a string or does not exist.
    """
    if not isinstance(file, dict) or "path" not in file:
        raise TypeError("file must be a dict with a 'path' key from VideoFactory.")
    if not isinstance(labels, list) or not all(
        isinstance(label, str) for label in labels
    ):
        raise TypeError("labels must be a list of strings.")
    if not labels:
        raise ValueError("Label list must not be empty.")

    fname = file["path"]
    self.fname = fname
    if not isinstance(fname, str) or not os.path.exists(fname):
        raise FileNotFoundError(f"File not found: {fname}")

    # Keep a private copy so later changes to the operator's label list
    # never reach back into the caller's list.
    self.labels = list(labels)

    try:
        # Analysis runs inside the try so that a failure during it still
        # honours `remove_after_processing`.
        self.analyze()
        return self.gen_data()
    finally:
        if remove_after_processing:
            with contextlib.suppress(FileNotFoundError):
                os.remove(fname)
[docs]
def cleanup(self) -> None:
    """Clean up resources used by the operator.

    Resets the per-run state and drops the operator's references to the
    processor and model. Safe to call more than once.
    """
    # Rebind rather than clear in place: `run` stores the list it is
    # given, so clearing it could empty a list the caller still holds.
    self.frame_images = []
    self.labels = []
    self.probs = None

    # Delete only what is present, so a repeated call (or a call after a
    # partially failed construction) does not raise AttributeError.
    for attr in ("processor", "model"):
        if hasattr(self, attr):
            delattr(self, attr)
[docs]
def state(self) -> dict[str, Any]:
    """Return the current state of the operator.

    Returns:
        dict: State of the operator, with keys:
            - `labels` (list): Copy of the current labels
            - `frame_images` (list): Raw bytes of each stored frame
            - `probs` (list): Label probabilities, or `[]` if none
            - `device` (str): Device the operator was set up with
            - `model`, `processor`: The loaded objects, or `None` if
              they are no longer set on the operator
    """
    return {
        # Copy so the snapshot does not change if the operator's own
        # list is modified afterwards.
        "labels": list(self.labels),
        "frame_images": [frame.tobytes() for frame in self.frame_images],
        "probs": self.probs.tolist() if self.probs is not None else [],
        "device": str(self.device),
        # `cleanup` deletes these attributes; report None instead of
        # raising AttributeError when they are gone.
        "model": getattr(self, "model", None),
        "processor": getattr(self, "processor", None),
    }