Torch

Install the extra with pip install datachain[torch], then import from datachain.torch to use the PyTorch functionality. DataChain.to_pytorch converts a chain into a PyTorch Dataset for downstream tasks such as model training or inference. The functions below help prepare data from the chain for use with PyTorch.

clip_similarity_scores

clip_similarity_scores(
    images: Union[None, Image, list[Image]],
    text: Union[None, str, list[str]],
    model: Any,
    preprocess: Callable,
    tokenizer: Callable,
    prob: bool = False,
    image_to_text: bool = True,
    device: Optional[Union[str, device]] = None,
) -> list[list[float]]

Calculate CLIP similarity scores between one or more images and/or text.

Parameters:

  • images

    Images to use as inputs.

  • text

    Text to use as inputs.

  • model

    Model from the clip, open_clip, or transformers package.

  • preprocess

    Image preprocessor to apply.

  • tokenizer

    Text tokenizer.

  • prob

    If True, return softmax probabilities (computed over each row) instead of raw scores.

  • image_to_text

    If True, compute image-to-text scores; if False, text-to-image. Ignored unless both images and text are provided.

  • device

    Device to use. Defaults to None, which uses the model's device. If set, the model is moved to this device.
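The scoring is plain scaled cosine similarity. Both sets of features are normalized to unit length and multiplied, and the result is scaled by 100, as in the source below. The following sketch reproduces that arithmetic with ordinary Python lists instead of torch tensors, so it runs without torch or a CLIP model. The helper names unit and scaled_cosine and the 2-D feature vectors are made up for illustration.

```python
import math


def unit(vec: list[float]) -> list[float]:
    """Scale a vector to unit L2 norm, like `features /= features.norm(...)`."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def scaled_cosine(
    rows: list[list[float]], cols: list[list[float]]
) -> list[list[float]]:
    """100 * cosine similarity for every (row, col) pair, like `100.0 * a @ b.T`."""
    rows = [unit(r) for r in rows]
    cols = [unit(c) for c in cols]
    return [[100.0 * sum(a * b for a, b in zip(r, c)) for c in cols] for r in rows]


# Hypothetical 2-D vectors standing in for encoder outputs.
image_feats = [[1.0, 0.0], [0.6, 0.8]]             # 2 images
text_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 texts

img_to_txt = scaled_cosine(image_feats, text_feats)   # image_to_text=True: 2 x 3
txt_to_img = scaled_cosine(text_feats, image_feats)   # image_to_text=False: 3 x 2
img_to_img = scaled_cosine(image_feats, image_feats)  # text=None: 2 x 2

for row in img_to_txt:
    print([round(x, 1) for x in row])
# [100.0, 0.0, 70.7]
# [60.0, 80.0, 99.0]
```

Setting image_to_text=False swaps the arguments, which only transposes the matrix. Passing only images (or only text) compares that set with itself. The result is an N x N matrix whose diagonal is 100.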

Example

Using https://github.com/openai/CLIP

>>> import clip
>>> model, preprocess = clip.load("ViT-B/32")
>>> clip_similarity_scores(img, "cat", model, preprocess, clip.tokenize)
[[21.813]]

Using https://github.com/mlfoundations/open_clip

>>> import open_clip
>>> model, _, preprocess = open_clip.create_model_and_transforms(
...     "ViT-B-32", pretrained="laion2b_s34b_b79k"
... )
>>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
>>> clip_similarity_scores(img, "cat", model, preprocess, tokenizer)
[[21.813]]

Using https://huggingface.co/docs/transformers/en/model_doc/clip

>>> from transformers import CLIPProcessor, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> clip_similarity_scores(
...     img, "cat", model, processor.image_processor, processor.tokenizer
... )
[[21.813]]

Image -> list of text

>>> clip_similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
[[21.813, 35.313]]

List of images -> text

>>> clip_similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
[[21.813], [83.123]]

List of images -> list of text

>>> clip_similarity_scores(
...     [img1, img2], ["cat", "dog"], model, preprocess, tokenizer
... )
[[21.813, 35.313], [83.123, 34.843]]

List of images -> list of images

>>> clip_similarity_scores([img1, img2], None, model, preprocess, tokenizer)
[[94.189, 37.092]]

List of text -> list of text

>>> clip_similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
[[67.334, 23.588]]

Text -> list of images

>>> clip_similarity_scores(
...     [img1, img2], "cat", model, preprocess, tokenizer, image_to_text=False
... )
[[19.708, 19.842]]

Show scores as softmax probabilities

>>> clip_similarity_scores(
...     img, ["cat", "dog"], model, preprocess, tokenizer, prob=True
... )
[[0.423, 0.577]]
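With prob=True, the raw scores go through a softmax along dim=1. Each row, which is one query scored against all candidates, then sums to 1. The following is a minimal pure-Python sketch of that step. The helper name softmax_rows and the raw scores are hypothetical.

```python
import math


def softmax_rows(logits: list[list[float]]) -> list[list[float]]:
    """Row-wise softmax, like `logits.softmax(dim=1)` in clip_similarity_scores."""
    probs = []
    for row in logits:
        peak = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# Hypothetical raw scores: one image against two captions.
raw = [[20.0, 22.0]]
print([round(p, 3) for p in softmax_rows(raw)[0]])  # [0.119, 0.881]
```

The scores are already multiplied by 100, so even a gap of a few points produces a fairly one-sided distribution.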

Source code in datachain/lib/clip.py
def clip_similarity_scores(
    images: Union[None, "Image.Image", list["Image.Image"]],
    text: Union[None, str, list[str]],
    model: Any,
    preprocess: Callable,
    tokenizer: Callable,
    prob: bool = False,
    image_to_text: bool = True,
    device: Optional[Union[str, torch.device]] = None,
) -> list[list[float]]:
    """
    Calculate CLIP similarity scores between one or more images and/or text.

    Parameters:
        images : Images to use as inputs.
        text : Text to use as inputs.
        model : Model from the clip, open_clip, or transformers package.
        preprocess : Image preprocessor to apply.
        tokenizer : Text tokenizer.
        prob : If True, return softmax probabilities instead of raw scores.
        image_to_text : If True, compute image-to-text scores; if False,
            text-to-image. Ignored unless both images and text are provided.
        device : Device to use. Defaults to None, which uses the model's
            device. If set, the model is moved to this device.


    Example:
        Using https://github.com/openai/CLIP
        ```py
        >>> import clip
        >>> model, preprocess = clip.load("ViT-B/32")
        >>> clip_similarity_scores(img, "cat", model, preprocess, clip.tokenize)
        [[21.813]]
        ```

        Using https://github.com/mlfoundations/open_clip
        ```py
        >>> import open_clip
        >>> model, _, preprocess = open_clip.create_model_and_transforms(
        ...     "ViT-B-32", pretrained="laion2b_s34b_b79k"
        ... )
        >>> tokenizer = open_clip.get_tokenizer("ViT-B-32")
        >>> clip_similarity_scores(img, "cat", model, preprocess, tokenizer)
        [[21.813]]
        ```

        Using https://huggingface.co/docs/transformers/en/model_doc/clip
        ```py
        >>> from transformers import CLIPProcessor, CLIPModel
        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        >>> clip_similarity_scores(
        ...     img, "cat", model, processor.image_processor, processor.tokenizer
        ... )
        [[21.813]]
        ```

        Image -> list of text
        ```py
        >>> clip_similarity_scores(img, ["cat", "dog"], model, preprocess, tokenizer)
        [[21.813, 35.313]]
        ```

        List of images -> text
        ```py
        >>> clip_similarity_scores([img1, img2], "cat", model, preprocess, tokenizer)
        [[21.813], [83.123]]
        ```

        List of images -> list of text
        ```py
        >>> clip_similarity_scores(
        ...     [img1, img2], ["cat", "dog"], model, preprocess, tokenizer
        ... )
        [[21.813, 35.313], [83.123, 34.843]]
        ```

        List of images -> list of images
        ```py
        >>> clip_similarity_scores([img1, img2], None, model, preprocess, tokenizer)
        [[94.189, 37.092]]
        ```

        List of text -> list of text
        ```py
        >>> clip_similarity_scores(None, ["cat", "dog"], model, preprocess, tokenizer)
        [[67.334, 23.588]]
        ```

        Text -> list of images
        ```py
        >>> clip_similarity_scores(
        ...     [img1, img2], "cat", model, preprocess, tokenizer, image_to_text=False
        ... )
        [[19.708, 19.842]]
        ```

        Show scores as softmax probabilities
        ```py
        >>> clip_similarity_scores(
        ...     img, ["cat", "dog"], model, preprocess, tokenizer, prob=True
        ... )
        [[0.423, 0.577]]
        ```
    """

    if device is None:
        if hasattr(model, "device"):
            device = model.device
        else:
            device = next(model.parameters()).device
    else:
        model = model.to(device)
    with torch.no_grad():
        if images is not None:
            encoder = _get_encoder(model, "image")
            image_features = convert_images(
                images, transform=preprocess, encoder=encoder, device=device
            )
            image_features /= image_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]

        if text is not None:
            encoder = _get_encoder(model, "text")
            text_features = convert_text(
                text, tokenizer, encoder=encoder, device=device
            )
            text_features /= text_features.norm(dim=-1, keepdim=True)  # type: ignore[union-attr]

        if images is not None and text is not None:
            if image_to_text:
                logits = 100.0 * image_features @ text_features.T  # type: ignore[operator,union-attr]
            else:
                logits = 100.0 * text_features @ image_features.T  # type: ignore[operator,union-attr]
        elif images is not None:
            logits = 100.0 * image_features @ image_features.T  # type: ignore[operator,union-attr]
        elif text is not None:
            logits = 100.0 * text_features @ text_features.T  # type: ignore[operator,union-attr]
        else:
            raise ValueError(
                "Error calculating CLIP similarity - "
                "provide at least one of images or text"
            )

        if prob:
            scores = logits.softmax(dim=1)
        else:
            scores = logits

        return scores.tolist()

convert_image

convert_image(
    img: Image,
    mode: str = "RGB",
    size: Optional[tuple[int, int]] = None,
    transform: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, device]] = None,
) -> Union[Image, Tensor]

Resize, transform, and otherwise convert an image.

Parameters:

  • img (Image) –

    PIL.Image object.

  • mode (str, default: 'RGB' ) –

    PIL.Image mode.

  • size (tuple[int, int], default: None ) –

    Size in (width, height) pixels for resizing.

  • transform (Callable, default: None ) –

    Torchvision transform or Hugging Face image processor to apply.

  • encoder (Callable, default: None ) –

    Model encoder to apply to the converted image.

  • device (str or device, default: None ) –

    Device to use.
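The order of operations in convert_image matters. The device move and the added batch dimension only happen inside the transform branch. The hypothetical helper below does no image processing. It only traces which steps convert_image would run for a given combination of arguments, mirroring the branches in the source below.

```python
from typing import Optional


def convert_image_steps(
    mode: str = "RGB",
    size: Optional[tuple[int, int]] = None,
    has_transform: bool = False,
    has_encoder: bool = False,
    device: Optional[str] = None,
) -> list[str]:
    """Hypothetical tracer: list the steps convert_image runs, in order."""
    steps = []
    if mode:
        steps.append(f"img.convert({mode!r})")
    if size:
        steps.append(f"img.resize({size})")  # size is (width, height)
    if has_transform:
        # For Hugging Face processors, pixel_values[0] is also turned into a tensor.
        steps.append("transform(img)")
        if device:
            steps.append(f"img.to({device!r})")
        if has_encoder:
            steps.append("img.unsqueeze(0)")  # add a batch dimension of 1
    if has_encoder:
        steps.append("encoder(img)")
    return steps


print(convert_image_steps(size=(224, 224), has_transform=True, has_encoder=True, device="cuda"))
```

For example, convert_image_steps(has_encoder=True, device="cuda") returns only the convert and encode steps. Without a transform, the PIL image is passed to the encoder as-is, with no device move and no batch dimension.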

Source code in datachain/lib/image.py
def convert_image(
    img: Image.Image,
    mode: str = "RGB",
    size: Optional[tuple[int, int]] = None,
    transform: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Union[Image.Image, torch.Tensor]:
    """
    Resize, transform, and otherwise convert an image.

    Args:
        img (Image): PIL.Image object.
        mode (str): PIL.Image mode.
        size (tuple[int, int]): Size in (width, height) pixels for resizing.
        transform (Callable): Torchvision transform or Hugging Face processor to apply.
        encoder (Callable): Model encoder to apply to the converted image.
        device (str or torch.device): Device to use.
    """
    if mode:
        img = img.convert(mode)
    if size:
        img = img.resize(size)
    if transform:
        img = transform(img)

        try:
            from transformers.image_processing_utils import BaseImageProcessor

            if isinstance(transform, BaseImageProcessor):
                img = torch.as_tensor(img.pixel_values[0]).clone().detach()  # type: ignore[assignment,attr-defined]
        except ImportError:
            pass
        if device:
            img = img.to(device)  # type: ignore[attr-defined]
        if encoder:
            img = img.unsqueeze(0)  # type: ignore[attr-defined]
    if encoder:
        img = encoder(img)
    return img

convert_images

convert_images(
    images: Union[Image, list[Image]],
    mode: str = "RGB",
    size: Optional[tuple[int, int]] = None,
    transform: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, device]] = None,
) -> Union[list[Image], Tensor]

Resize, transform, and otherwise convert one or more images.

Parameters:

  • images ((Image, list[Image])) –

    PIL.Image object or list of objects.

  • mode (str, default: 'RGB' ) –

    PIL.Image mode.

  • size (tuple[int, int], default: None ) –

    Size in (width, height) pixels for resizing.

  • transform (Callable, default: None ) –

    Torchvision transform or Hugging Face image processor to apply.

  • encoder (Callable, default: None ) –

    Model encoder to apply to the batch of converted images.

  • device (str or device, default: None ) –

    Device to use.
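convert_images does not pass encoder down to convert_image. Each image is converted on its own, and the results are stacked with torch.stack when they are tensors. The encoder then runs once on the whole batch. The sketch below mimics that flow with plain lists. The names convert_images_sketch and toy_encoder and the pixel values are hypothetical. A single PIL image, which the real function wraps in a list, is not handled here.

```python
from typing import Callable


def convert_images_sketch(
    images: list[list[float]],
    transform: Callable[[list[float]], list[float]],
    encoder: Callable[[list[list[float]]], list[list[float]]],
) -> list[list[float]]:
    # Step 1: transform each image independently (no encoder here).
    converted = [transform(img) for img in images]
    # Step 2: the real code stacks tensors into one batch with torch.stack;
    # here the list of per-image results already plays that role.
    batch = converted
    # Step 3: a single encoder call for the whole batch.
    return encoder(batch)


calls: list[int] = []


def toy_encoder(batch: list[list[float]]) -> list[list[float]]:
    calls.append(len(batch))  # record the batch size of each call
    return [[sum(img)] for img in batch]


features = convert_images_sketch(
    [[0.0, 255.0], [255.0, 255.0], [0.0, 0.0]],
    lambda img: [p / 255 for p in img],  # hypothetical scaling "transform"
    toy_encoder,
)
print(features, calls)  # [[1.0], [2.0], [0.0]] [3]
```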

Source code in datachain/lib/image.py
def convert_images(
    images: Union[Image.Image, list[Image.Image]],
    mode: str = "RGB",
    size: Optional[tuple[int, int]] = None,
    transform: Optional[Callable] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Union[list[Image.Image], torch.Tensor]:
    """
    Resize, transform, and otherwise convert one or more images.

    Args:
        images (Image, list[Image]): PIL.Image object or list of objects.
        mode (str): PIL.Image mode.
        size (tuple[int, int]): Size in (width, height) pixels for resizing.
        transform (Callable): Torchvision transform or Hugging Face processor to apply.
        encoder (Callable): Model encoder to apply to the batch of converted images.
        device (str or torch.device): Device to use.
    """
    if isinstance(images, Image.Image):
        images = [images]

    converted = [
        convert_image(img, mode, size, transform, device=device) for img in images
    ]

    if isinstance(converted[0], torch.Tensor):
        converted = torch.stack(converted)  # type: ignore[assignment,arg-type]

    if encoder:
        converted = encoder(converted)

    return converted  # type: ignore[return-value]

convert_text

convert_text(
    text: Union[str, list[str]],
    tokenizer: Optional[Callable] = None,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, device]] = None,
) -> Union[str, list[str], Tensor]

Tokenize and otherwise transform text.

Parameters:

  • text (str or list[str]) –

    Text or list of texts to convert.

  • tokenizer (Callable, default: None ) –

    Tokenizer used to tokenize the text. If None, the text is returned unchanged.

  • tokenizer_kwargs (dict, default: None ) –

    Additional kwargs to pass when calling tokenizer.

  • encoder (Callable, default: None ) –

    Encode text using model.

  • device (str or device, default: None ) –

    Device to use.
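convert_text follows a similar pattern. With no tokenizer, the input is returned untouched. Otherwise, a single string is wrapped in a list and the tokenizer is called with any tokenizer_kwargs. For Hugging Face tokenizers, the input_ids field is taken from the result. The sketch below covers only those steps. It returns plain lists rather than tensors and skips the device move and encoder. The names convert_text_sketch and toy_tokenizer and the max_length argument are hypothetical.

```python
from typing import Any, Callable, Optional, Union


def convert_text_sketch(
    text: Union[str, list[str]],
    tokenizer: Optional[Callable] = None,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
) -> Union[str, list[str], list[list[int]]]:
    if not tokenizer:
        return text  # no tokenizer: text comes back unchanged
    if isinstance(text, str):
        text = [text]  # a single string becomes a batch of one
    res = tokenizer(text, **(tokenizer_kwargs or {}))
    # The real function checks isinstance(tokenizer, PreTrainedTokenizerBase);
    # this sketch checks for an `input_ids` attribute on the result instead.
    return res.input_ids if hasattr(res, "input_ids") else res


def toy_tokenizer(texts: list[str], max_length: int = 8) -> list[list[int]]:
    """Hypothetical tokenizer: character codes, padded/truncated to max_length."""
    return [([ord(c) for c in t] + [0] * max_length)[:max_length] for t in texts]


print(convert_text_sketch("cat", toy_tokenizer, {"max_length": 4}))  # [[99, 97, 116, 0]]
```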

Source code in datachain/lib/text.py
def convert_text(
    text: Union[str, list[str]],
    tokenizer: Optional[Callable] = None,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
    encoder: Optional[Callable] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Union[str, list[str], torch.Tensor]:
    """
    Tokenize and otherwise transform text.

    Args:
        text (str or list[str]): Text or list of texts to convert.
        tokenizer (Callable): Tokenizer used to tokenize the text. If None, the
            text is returned unchanged.
        tokenizer_kwargs (dict): Additional kwargs to pass when calling tokenizer.
        encoder (Callable): Encode text using model.
        device (str or torch.device): Device to use.
    """
    if not tokenizer:
        return text

    if isinstance(text, str):
        text = [text]

    if tokenizer_kwargs:
        res = tokenizer(text, **tokenizer_kwargs)
    else:
        res = tokenizer(text)

    tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
    tokens = torch.as_tensor(tokens).clone().detach()
    if device:
        tokens = tokens.to(device)

    if not encoder:
        return tokens

    return encoder(tokens)

label_to_int

label_to_int(value: str, classes: list) -> int

Given a value and list of classes, return the index of the value's class.
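label_to_int is a thin wrapper around list.index. A typical use is to build a fixed, sorted class list once and map every label to its integer target. The function is reproduced here so the snippet runs without datachain installed. The labels are made up.

```python
def label_to_int(value: str, classes: list) -> int:
    """Same one-liner as the source below, reproduced so this runs standalone."""
    return classes.index(value)


# Build a fixed class list once (sorted, so indices are stable across runs),
# then map each string label to its integer target.
labels = ["dog", "cat", "dog", "bird"]
classes = sorted(set(labels))  # ['bird', 'cat', 'dog']
targets = [label_to_int(label, classes) for label in labels]
print(targets)  # [2, 1, 2, 0]
```

A label missing from classes raises ValueError (from list.index). Each lookup also scans the list, so for many classes a precomputed dict such as {c: i for i, c in enumerate(classes)} may be preferable.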

Source code in datachain/lib/pytorch.py
def label_to_int(value: str, classes: list) -> int:
    """Given a value and list of classes, return the index of the value's class."""
    return classes.index(value)