lib/bumblebee.ex (6 additions, 0 deletions)

@@ -116,6 +116,9 @@ defmodule Bumblebee do
"CLIPModel" => {Bumblebee.Multimodal.Clip, :base},
"CLIPTextModel" => {Bumblebee.Text.ClipText, :base},
"CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base},
"SiglipModel" => {Bumblebee.Multimodal.SigLip, :base},
"SiglipTextModel" => {Bumblebee.Text.SigLipText, :base},
"SiglipVisionModel" => {Bumblebee.Vision.SigLipVision, :base},
Comment on lines +119 to +121 (Member):

We are missing this mapping:

Suggested change:

      "SiglipModel" => {Bumblebee.Multimodal.SigLip, :base},
      "SiglipTextModel" => {Bumblebee.Text.SigLipText, :base},
      "SiglipVisionModel" => {Bumblebee.Vision.SigLipVision, :base},
+     "SiglipForImageClassification" => {Bumblebee.Vision.SigLipVision, :for_image_classification},

"ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base},
"ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification},
"ConvNextModel" => {Bumblebee.Vision.ConvNext, :base},
@@ -229,6 +232,8 @@ defmodule Bumblebee do

  @transformers_class_to_featurizer %{
    "CLIPFeatureExtractor" => Bumblebee.Vision.ClipFeaturizer,
    "SiglipImageProcessor" => Bumblebee.Vision.ClipFeaturizer,
Comment (Member):

SiglipImageProcessor is different from CLIP (e.g. it doesn't center crop). It should be its own featurizer.
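For reference, the upstream SiglipImageProcessor defaults are just resize, rescale, and normalize with mean/std 0.5. A minimal sketch of the core transform a dedicated featurizer would compute, assuming NxImage for resizing (the preprocess name and the u8 HWC input are illustrative, not the featurizer behaviour API):

    # Sketch only: SigLIP-style preprocessing, no center crop.
    # Input: a u8 {height, width, channels} image tensor.
    defp preprocess(image, size) do
      image
      |> NxImage.resize({size, size}, method: :bilinear)
      # rescale to [0, 1], then normalize with mean 0.5 / std 0.5
      |> Nx.divide(255.0)
      |> Nx.subtract(0.5)
      |> Nx.divide(0.5)
    end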

"Siglip2ImageProcessor" => Bumblebee.Vision.ClipFeaturizer,
Comment (Member):

Siglip2 is a separate model altogether.

Suggested change:

-     "Siglip2ImageProcessor" => Bumblebee.Vision.ClipFeaturizer,

"ConvNextFeatureExtractor" => Bumblebee.Vision.ConvNextFeaturizer,
"DeiTFeatureExtractor" => Bumblebee.Vision.DeitFeaturizer,
"ViTFeatureExtractor" => Bumblebee.Vision.VitFeaturizer,
@@ -257,6 +262,7 @@ defmodule Bumblebee do
"distilbert" => :distilbert,
"camembert" => :camembert,
"clip" => :clip,
"siglip" => :siglip,
"gemma" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
lib/bumblebee/multimodal/siglip.ex (new file, 250 additions)

@@ -0,0 +1,250 @@
defmodule Bumblebee.Multimodal.SigLip do
  alias Bumblebee.Shared

  options =
    [
      text_spec: [
        default: nil,
        doc: "the specification of the text model. See `Bumblebee.Text.SigLipText` for details"
      ],
      vision_spec: [
        default: nil,
        doc:
          "the specification of the vision model. See `Bumblebee.Vision.SigLipVision` for details"
      ],
      logit_scale_initial_value: [
        default: 2.6592,
        doc: "the initial value for the scaling layer used to scale similarity logits"
      ],
      logit_bias_initial_value: [
        default: -10.0,
        doc: "the initial value for the bias added to similarity logits"
      ]
    ]

  @moduledoc """
  The SigLIP model for text-image similarity.

  SigLIP uses a sigmoid loss function instead of the contrastive loss used
  by CLIP, which allows for better scaling and more stable training.

  ## Architectures

    * `:base` - the base SigLIP model

  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`

      Indices of input sequence tokens in the vocabulary.

    * `"attention_mask"` - `{batch_size, sequence_length}`

      Mask indicating which tokens to attend to. This is used to ignore
      padding tokens, which are added when processing a batch of sequences
      with different length.

    * `"position_ids"` - `{batch_size, sequence_length}`

      Indices of positions of each input sequence tokens in the position
      embeddings.

    * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}`

      Featurized image pixel values.

  ## Global layer options

  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}

  ## Configuration

  #{Shared.options_doc(options)}

  ## References

    * [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343)

  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable

  alias Bumblebee.Layers

  @impl true
  def architectures(), do: [:base]

  @impl true
  def config(spec, opts) do
    Shared.put_config_attrs(spec, opts)
  end

  @impl true
  def input_template(%{vision_spec: vision_spec}) do
    vision_shape =
      {1, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}

    %{
      "input_ids" => Nx.template({1, 1}, :u32),
      "pixel_values" => Nx.template(vision_shape, :f32)
    }
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    %{text_spec: text_spec, vision_spec: vision_spec} = spec

    text_shape = {nil, nil}

    vision_shape =
      {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels}

    inputs =
      Bumblebee.Utils.Model.inputs_to_map([
        Axon.input("input_ids", shape: text_shape),
        Axon.input("attention_mask", optional: true, shape: text_shape),
        Axon.input("position_ids", optional: true, shape: text_shape),
        Axon.input("pixel_values", shape: vision_shape)
      ])

    text_model =
      text_spec
      |> Bumblebee.build_model()
      |> Bumblebee.Utils.Axon.prefix_names("text_model.")
      |> Bumblebee.Utils.Axon.plug_inputs(%{
        "input_ids" => inputs["input_ids"],
        "attention_mask" => inputs["attention_mask"],
        "position_ids" => inputs["position_ids"]
      })

    vision_model =
      vision_spec
      |> Bumblebee.build_model()
      |> Bumblebee.Utils.Axon.prefix_names("vision_model.")
      |> Bumblebee.Utils.Axon.plug_inputs(%{
        "pixel_values" => inputs["pixel_values"]
      })

    text_embedding =
      text_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.nx(&normalize/1)

    image_embedding =
      vision_model
      |> Axon.nx(& &1.pooled_state)
      |> Axon.nx(&normalize/1)

    similarity = Layers.cosine_similarity(text_embedding, image_embedding)

    logits_per_text =
      similarity
      |> scale_layer(
        name: "logit_scale",
        initializer: Axon.Initializers.full(spec.logit_scale_initial_value)
      )
      |> bias_layer(
        name: "logit_bias",
        initializer: Axon.Initializers.full(spec.logit_bias_initial_value)
      )
Comment on lines +146 to +149 (Member):

There is Axon.bias, which should be the same?
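If so, the hand-rolled bias_layer below could go away. A sketch, assuming Axon.bias/2 accepts :name and :bias_initializer; note that Axon.bias typically creates a per-feature bias parameter, whereas the PR uses a scalar, so the parameter shape would need checking:

    # Sketch: replacing bias_layer with Axon.bias/2 (option names assumed).
    logits_per_text =
      similarity
      |> scale_layer(
        name: "logit_scale",
        initializer: Axon.Initializers.full(spec.logit_scale_initial_value)
      )
      |> Axon.bias(
        name: "logit_bias",
        bias_initializer: Axon.Initializers.full(spec.logit_bias_initial_value)
      )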


    logits_per_image = Axon.transpose(logits_per_text)

    Layers.output(%{
      logits_per_text: logits_per_text,
      logits_per_image: logits_per_image,
      text_embedding: text_embedding,
      image_embedding: image_embedding
    })
  end

  defp normalize(x) do
    Nx.divide(x, Nx.LinAlg.norm(x, ord: 2, axes: [-1], keep_axes: true))
  end
Comment (Member):

Can use this:

    defn normalize(tensor) do

    end
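Presumably the suggestion is the same L2 normalization written as a numerical definition, something like:

    # Assumed reconstruction: same computation as normalize/1 above,
    # written as a defn (requires `import Nx.Defn` in the module).
    defn normalize(tensor) do
      tensor / Nx.LinAlg.norm(tensor, ord: 2, axes: [-1], keep_axes: true)
    end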

  defp scale_layer(input, opts) do
Comment (Member):

We have this in CLIP too:

    defp exp_scale(input, opts) do
      opts = Keyword.validate!(opts, [:name, scale_initializer: Axon.Initializers.full(1.0e-6)])
      name = opts[:name]
      scale_initializer = opts[:scale_initializer]

      scale_param = Axon.param("scale", fn _ -> {} end, initializer: scale_initializer)

      Axon.layer(
        fn input, scale, _opts ->
          Nx.multiply(input, Nx.exp(scale))
        end,
        [input, scale_param],
        name: name,
        op_name: :exp_scale
      )
    end

So it would make sense to share this. But we already have a scale layer that multiplies:

    def scale(%Axon{} = x, opts \\ []) do

What I would do is add an option to Layers.scale called :transform, and it would accept an Nx function to apply to the scale itself. Then here and in CLIP we would pass transform: &Nx.exp/1.
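A sketch of that proposal (the :transform option does not exist on Layers.scale today; the other option names follow the snippet above):

    # Sketch: Layers.scale with a :transform applied to the scale parameter.
    def scale(%Axon{} = x, opts \\ []) do
      opts =
        Keyword.validate!(opts, [
          :name,
          scale_initializer: Axon.Initializers.full(1.0),
          transform: &Function.identity/1
        ])

      scale_param =
        Axon.param("scale", fn _ -> {} end, initializer: opts[:scale_initializer])

      Axon.layer(
        fn input, scale, _opts ->
          Nx.multiply(input, opts[:transform].(scale))
        end,
        [x, scale_param],
        name: opts[:name],
        op_name: :scale
      )
    end

The call site here and in CLIP would then be scale(similarity, name: "logit_scale", transform: &Nx.exp/1).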

    name = opts[:name]
    initializer = opts[:initializer] || Axon.Initializers.full(1.0)

    scale_param = Axon.param("scale", fn _ -> {} end, initializer: initializer)

    Axon.layer(
      fn input, scale, _opts ->
        Nx.multiply(input, Nx.exp(scale))
      end,
      [input, scale_param],
      name: name,
      op_name: :logit_scale
    )
  end

  defp bias_layer(input, opts) do
    name = opts[:name]
    initializer = opts[:initializer] || Axon.Initializers.full(0.0)

    bias_param = Axon.param("bias", fn _ -> {} end, initializer: initializer)

    Axon.layer(
      fn input, bias, _opts ->
        Nx.add(input, bias)
      end,
      [input, bias_param],
      name: name,
      op_name: :logit_bias
    )
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    def load(spec, data) do
      import Shared.Converters

      {text_data, data} = Map.pop(data, "text_config", %{})
      {vision_data, data} = Map.pop(data, "vision_config", %{})

      text_spec =
        Bumblebee.Text.SigLipText
        |> Bumblebee.configure()
        |> Bumblebee.HuggingFace.Transformers.Config.load(text_data)

      vision_spec =
        Bumblebee.Vision.SigLipVision
        |> Bumblebee.configure()
        |> Bumblebee.HuggingFace.Transformers.Config.load(vision_data)

      opts =
        convert!(data,
          logit_scale_initial_value: {"logit_scale_init_value", number()},
          logit_bias_initial_value: {"logit_bias_init_value", number()}
        )

      @for.config(spec, opts ++ [text_spec: text_spec, vision_spec: vision_spec])
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    alias Bumblebee.HuggingFace.Transformers

    def params_mapping(spec) do
      text_mapping =
        spec.text_spec
        |> Transformers.Model.params_mapping()
        |> Transformers.Utils.prefix_params_mapping("text_model", nil)

      vision_mapping =
        spec.vision_spec
        |> Transformers.Model.params_mapping()
        |> Transformers.Utils.prefix_params_mapping("vision_model", nil)

      %{
        "logit_scale" => %{
          "scale" => {[{"logit_scale", "logit_scale"}], fn [scale] -> Nx.squeeze(scale) end}
        },
        "logit_bias" => %{
          "bias" => {[{"logit_bias", "logit_bias"}], fn [bias] -> Nx.squeeze(bias) end}
        }
      }
      |> Map.merge(text_mapping)
      |> Map.merge(vision_mapping)
    end
  end
end
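With the mappings in place, end-to-end usage should follow the standard Bumblebee flow. A hypothetical smoke test (the checkpoint name is the upstream HF repo and the dummy image is illustrative), using a sigmoid over the logits, which is SigLIP's design, rather than CLIP's softmax:

    # Hypothetical usage sketch for the :base architecture above.
    {:ok, model_info} = Bumblebee.load_model({:hf, "google/siglip-base-patch16-224"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "google/siglip-base-patch16-224"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/siglip-base-patch16-224"})

    text_inputs =
      Bumblebee.apply_tokenizer(tokenizer, ["a photo of a cat", "a photo of a dog"])

    # Dummy 224x224 RGB image standing in for real input
    image = Nx.broadcast(Nx.tensor(128, type: :u8), {224, 224, 3})
    image_inputs = Bumblebee.apply_featurizer(featurizer, image)

    inputs = Map.merge(text_inputs, image_inputs)
    outputs = Axon.predict(model_info.model, model_info.params, inputs)

    # SigLIP scores each text-image pair independently
    probabilities = Nx.sigmoid(outputs.logits_per_image)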