diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 13928ee1..3540d6a7 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -116,6 +116,9 @@ defmodule Bumblebee do "CLIPModel" => {Bumblebee.Multimodal.Clip, :base}, "CLIPTextModel" => {Bumblebee.Text.ClipText, :base}, "CLIPVisionModel" => {Bumblebee.Vision.ClipVision, :base}, + "SiglipModel" => {Bumblebee.Multimodal.SigLip, :base}, + "SiglipTextModel" => {Bumblebee.Text.SigLipText, :base}, + "SiglipVisionModel" => {Bumblebee.Vision.SigLipVision, :base}, "ControlNetModel" => {Bumblebee.Diffusion.ControlNet, :base}, "ConvNextForImageClassification" => {Bumblebee.Vision.ConvNext, :for_image_classification}, "ConvNextModel" => {Bumblebee.Vision.ConvNext, :base}, @@ -229,6 +232,8 @@ defmodule Bumblebee do @transformers_class_to_featurizer %{ "CLIPFeatureExtractor" => Bumblebee.Vision.ClipFeaturizer, + "SiglipImageProcessor" => Bumblebee.Vision.ClipFeaturizer, + "Siglip2ImageProcessor" => Bumblebee.Vision.ClipFeaturizer, "ConvNextFeatureExtractor" => Bumblebee.Vision.ConvNextFeaturizer, "DeiTFeatureExtractor" => Bumblebee.Vision.DeitFeaturizer, "ViTFeatureExtractor" => Bumblebee.Vision.VitFeaturizer, @@ -257,6 +262,7 @@ defmodule Bumblebee do "distilbert" => :distilbert, "camembert" => :camembert, "clip" => :clip, + "siglip" => :siglip, "gemma" => :gemma, "gpt_neox" => :gpt_neo_x, "gpt2" => :gpt2, diff --git a/lib/bumblebee/multimodal/siglip.ex b/lib/bumblebee/multimodal/siglip.ex new file mode 100644 index 00000000..49fde854 --- /dev/null +++ b/lib/bumblebee/multimodal/siglip.ex @@ -0,0 +1,250 @@ +defmodule Bumblebee.Multimodal.SigLip do + alias Bumblebee.Shared + + options = + [ + text_spec: [ + default: nil, + doc: "the specification of the text model. See `Bumblebee.Text.SigLipText` for details" + ], + vision_spec: [ + default: nil, + doc: + "the specification of the vision model. See `Bumblebee.Vision.SigLipVision` for details" + ], + logit_scale_initial_value: [ + default: 2.6592, + doc: "the initial value for the scaling layer used to scale similarity logits" + ], + logit_bias_initial_value: [ + default: -10.0, + doc: "the initial value for the bias added to similarity logits" + ] + ] + + @moduledoc """ + The SigLIP model for text-image similarity. + + SigLIP uses a sigmoid loss function instead of the contrastive loss used + by CLIP, which allows for better scaling and more stable training. + + ## Architectures + + * `:base` - the base SigLIP model + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}` + + Featurized image pixel values. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + + ## References + + * [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) + + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base] + + @impl true + def config(spec, opts) do + Shared.put_config_attrs(spec, opts) + end + + @impl true + def input_template(%{vision_spec: vision_spec}) do + vision_shape = {1, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels} + + %{ + "input_ids" => Nx.template({1, 1}, :u32), + "pixel_values" => Nx.template(vision_shape, :f32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + %{text_spec: text_spec, vision_spec: vision_spec} = spec + + text_shape = {nil, nil} + vision_shape = {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels} + + inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", shape: text_shape), + Axon.input("attention_mask", optional: true, shape: text_shape), + Axon.input("position_ids", optional: true, shape: text_shape), + Axon.input("pixel_values", shape: vision_shape) + ]) + + text_model = + text_spec + |> Bumblebee.build_model() + |> Bumblebee.Utils.Axon.prefix_names("text_model.") + |> Bumblebee.Utils.Axon.plug_inputs(%{ + "input_ids" => inputs["input_ids"], + "attention_mask" => inputs["attention_mask"], + "position_ids" => inputs["position_ids"] + }) + + vision_model = + vision_spec + |> Bumblebee.build_model() + |> Bumblebee.Utils.Axon.prefix_names("vision_model.") + |> Bumblebee.Utils.Axon.plug_inputs(%{ + "pixel_values" => inputs["pixel_values"] + }) + + text_embedding = + text_model + |> Axon.nx(& &1.pooled_state) + |> Axon.nx(&normalize/1) + + image_embedding = + vision_model + |> Axon.nx(& &1.pooled_state) + |> Axon.nx(&normalize/1) + + similarity = Layers.cosine_similarity(text_embedding, image_embedding) + + logits_per_text = + similarity + |> scale_layer( + name: "logit_scale", + initializer: Axon.Initializers.full(spec.logit_scale_initial_value) + ) + |> bias_layer( + name: "logit_bias", + initializer: Axon.Initializers.full(spec.logit_bias_initial_value) + ) + + logits_per_image = Axon.transpose(logits_per_text) + + Layers.output(%{ + logits_per_text: logits_per_text, + logits_per_image: logits_per_image, + text_embedding: text_embedding, + image_embedding: image_embedding + }) + end + + defp normalize(x) do + Nx.divide(x, Nx.LinAlg.norm(x, ord: 2, axes: [-1], keep_axes: true)) + end + + defp scale_layer(input, opts) do + name = opts[:name] + initializer = opts[:initializer] || Axon.Initializers.full(1.0) + + scale_param = Axon.param("scale", fn _ -> {} end, initializer: initializer) + + Axon.layer( + fn input, scale, _opts -> + Nx.multiply(input, Nx.exp(scale)) + end, + [input, scale_param], + name: name, + op_name: :logit_scale + ) + end + + defp bias_layer(input, opts) do + name = opts[:name] + initializer = opts[:initializer] || Axon.Initializers.full(0.0) + + bias_param = Axon.param("bias", fn _ -> {} end, initializer: initializer) + + Axon.layer( + fn input, bias, _opts -> + Nx.add(input, bias) + end, + [input, bias_param], + name: name, + op_name: :logit_bias + ) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import 
Shared.Converters + + {text_data, data} = Map.pop(data, "text_config", %{}) + {vision_data, data} = Map.pop(data, "vision_config", %{}) + + text_spec = + Bumblebee.Text.SigLipText + |> Bumblebee.configure() + |> Bumblebee.HuggingFace.Transformers.Config.load(text_data) + + vision_spec = + Bumblebee.Vision.SigLipVision + |> Bumblebee.configure() + |> Bumblebee.HuggingFace.Transformers.Config.load(vision_data) + + opts = + convert!(data, + logit_scale_initial_value: {"logit_scale_init_value", number()}, + logit_bias_initial_value: {"logit_bias_init_value", number()} + ) + + @for.config(spec, opts ++ [text_spec: text_spec, vision_spec: vision_spec]) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + alias Bumblebee.HuggingFace.Transformers + + def params_mapping(spec) do + text_mapping = + spec.text_spec + |> Transformers.Model.params_mapping() + |> Transformers.Utils.prefix_params_mapping("text_model", nil) + + vision_mapping = + spec.vision_spec + |> Transformers.Model.params_mapping() + |> Transformers.Utils.prefix_params_mapping("vision_model", nil) + + %{ + "logit_scale" => %{ + "scale" => {[{"logit_scale", "logit_scale"}], fn [scale] -> Nx.squeeze(scale) end} + }, + "logit_bias" => %{ + "bias" => {[{"logit_bias", "logit_bias"}], fn [bias] -> Nx.squeeze(bias) end} + } + } + |> Map.merge(text_mapping) + |> Map.merge(vision_mapping) + end + end +end diff --git a/lib/bumblebee/text/siglip_text.ex b/lib/bumblebee/text/siglip_text.ex new file mode 100644 index 00000000..f29269c2 --- /dev/null +++ b/lib/bumblebee/text/siglip_text.ex @@ -0,0 +1,261 @@ +defmodule Bumblebee.Text.SigLipText do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 32000, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 64, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + hidden_size: [ + default: 768, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size: [ + default: 3072, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + attention_dropout_rate: [ + default: 0.0, + doc: "the dropout rate for attention weights" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by the layer normalization layers" + ] + ] ++ Shared.token_options(pad_token_id: 1) + + @moduledoc """ + The SigLIP model for text encoding. + + ## Architectures + + * `:base` - the base text model + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + + ## References + + * [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) + + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base] + + @impl true + def config(spec, opts) do + Shared.put_config_attrs(spec, opts) + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :u32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs() + + inputs + |> core(spec) + |> Layers.output() + end + + defp inputs() do + shape = {nil, nil} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape) + ]) + end + + defp core(inputs, spec) do + input_ids = inputs["input_ids"] + + embeddings = embedder(input_ids, inputs["position_ids"], spec, name: "embedder") + encoder_outputs = encoder(embeddings, inputs["attention_mask"], spec, name: "encoder") + + hidden_state = + Axon.layer_norm( + encoder_outputs.hidden_state, + epsilon: spec.layer_norm_epsilon, + name: "norm" + ) + + # Take the last token for pooling (SigLIP pools from last position) + last_token = + Axon.nx(hidden_state, fn x -> + seq_len = Nx.axis_size(x, 1) + x |> Nx.slice_along_axis(seq_len - 1, 1, axis: 1) |> Nx.squeeze(axes: [1]) + end) + + pooled_state = + Axon.dense(last_token, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: "head" + ) + + %{ + hidden_state: hidden_state, + pooled_state: pooled_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions + } + end + + defp embedder(input_ids, position_ids, spec, opts) do + name = opts[:name] + + position_ids = + Layers.default position_ids do + Layers.default_position_ids(input_ids) + end + + input_embeddings = + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "token_embedding") + ) + + position_embeddings = + Axon.embedding(position_ids, spec.max_positions, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "position_embedding") + ) + + Axon.add(input_embeddings, position_embeddings) + end + + defp encoder(embeddings, attention_mask, spec, opts) do + name = opts[:name] + + Layers.Transformer.blocks(embeddings, + attention_mask: attention_mask, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(scale: 0.01), + dropout_rate: 0.0, + attention_dropout_rate: spec.attention_dropout_rate, + layer_norm: [ + epsilon: spec.layer_norm_epsilon + ], + ffn: [ + intermediate_size: spec.intermediate_size, + activation: spec.activation + ], + block_type: :norm_first, + name: join(name, "blocks") + ) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, %{"model_type" => "siglip", "text_config" => data}) do + load(spec, data) + end + + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: 
{"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + attention_dropout_rate: {"attention_dropout", number()}, + layer_norm_epsilon: {"layer_norm_eps", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "text_model.embeddings.token_embedding", + "embedder.position_embedding" => "text_model.embeddings.position_embedding", + "encoder.blocks.{n}.self_attention.query" => + "text_model.encoder.layers.{n}.self_attn.q_proj", + "encoder.blocks.{n}.self_attention.key" => + "text_model.encoder.layers.{n}.self_attn.k_proj", + "encoder.blocks.{n}.self_attention.value" => + "text_model.encoder.layers.{n}.self_attn.v_proj", + "encoder.blocks.{n}.self_attention.output" => + "text_model.encoder.layers.{n}.self_attn.out_proj", + "encoder.blocks.{n}.self_attention_norm" => "text_model.encoder.layers.{n}.layer_norm1", + "encoder.blocks.{n}.ffn.intermediate" => "text_model.encoder.layers.{n}.mlp.fc1", + "encoder.blocks.{n}.ffn.output" => "text_model.encoder.layers.{n}.mlp.fc2", + "encoder.blocks.{n}.output_norm" => "text_model.encoder.layers.{n}.layer_norm2", + "norm" => "text_model.final_layer_norm", + "head" => "text_model.head" + } + end + end +end diff --git a/lib/bumblebee/vision/siglip_vision.ex b/lib/bumblebee/vision/siglip_vision.ex new file mode 100644 index 00000000..cfefa9b0 --- /dev/null +++ b/lib/bumblebee/vision/siglip_vision.ex @@ -0,0 +1,438 @@ +defmodule Bumblebee.Vision.SigLipVision do + alias Bumblebee.Shared + + options = + [ + image_size: [ + default: 224, + doc: "the size of the input spatial dimensions" + ], + num_channels: [ + default: 3, + doc: "the number of channels in the input" + ], + patch_size: [ + default: 16, + doc: "the size of the patch spatial dimensions" + ], + hidden_size: [ + default: 768, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size: [ + default: 3072, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + attention_dropout_rate: [ + default: 0.0, + doc: "the dropout rate for attention weights" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by the layer normalization layers" + ] + ] ++ Shared.common_options([:num_labels, :id_to_label]) + + @moduledoc """ + The SigLIP model for image encoding. + + ## Architectures + + * `:base` - the base image model + + * `:for_image_classification` - SigLIP vision encoder with a classification + head. The head consists of a single dense layer on top of the mean-pooled + patch embeddings + + ## Inputs + + * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}` + + Featurized image pixel values. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + + ## References + + * [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) + + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base, :for_image_classification] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(spec) do + %{ + "pixel_values" => + Nx.template({1, spec.image_size, spec.image_size, spec.num_channels}, :f32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_image_classification} = spec) do + inputs = inputs(spec) + outputs = core(inputs, spec) + + pooled = + Axon.nx(outputs.hidden_state, fn hidden_state -> + Nx.mean(hidden_state, axes: [1]) + end) + + logits = + Axon.dense(pooled, spec.num_labels, + kernel_initializer: Axon.Initializers.normal(), + name: "image_classification_head.output" + ) + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions + }) + end + + defp inputs(spec) do + shape = {nil, spec.image_size, spec.image_size, spec.num_channels} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("pixel_values", shape: shape) + ]) + end + + defp core(inputs, spec) do + embeddings = embedder(inputs["pixel_values"], spec, name: "embedder") + + encoder_outputs = encoder(embeddings, spec, name: "encoder") + + hidden_state = + Axon.layer_norm(encoder_outputs.hidden_state, + epsilon: spec.layer_norm_epsilon, + name: "post_norm" + ) + + pooled_state = attention_pooling_head(hidden_state, spec, name: "head") + + %{ + hidden_state: hidden_state, + pooled_state: pooled_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions + } + end + + defp embedder(pixel_values, spec, opts) do + name = opts[:name] + + patch_embeddings = patch_embedding(pixel_values, spec, name: join(name, "patch_embedding")) + + num_patches = div(spec.image_size, spec.patch_size) ** 2 + position_ids = position_ids(num_patches) + + position_embeddings = + Axon.embedding(position_ids, num_patches, spec.hidden_size, + name: join(name, "position_embedding") + ) + + Axon.add(patch_embeddings, position_embeddings) + end + + defp patch_embedding(pixel_values, spec, opts) do + name = opts[:name] + + pixel_values + |> Axon.conv(spec.hidden_size, + kernel_size: spec.patch_size, + strides: spec.patch_size, + padding: :valid, + kernel_initializer: Axon.Initializers.normal(), + name: name + ) + |> Axon.reshape({:batch, :auto, spec.hidden_size}, name: join(name, "reshape")) + end + + defp position_ids(num_position_ids) do + Axon.layer( + fn _opts -> Nx.iota({1, num_position_ids}) end, + [], + op_name: :position_ids + ) + end + + defp encoder(embeddings, spec, opts) do + name = opts[:name] + + Layers.Transformer.blocks(embeddings, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(scale: 0.01), + dropout_rate: 0.0, + attention_dropout_rate: 
spec.attention_dropout_rate, + layer_norm: [ + epsilon: spec.layer_norm_epsilon + ], + ffn: [ + intermediate_size: spec.intermediate_size, + activation: spec.activation + ], + block_type: :norm_first, + name: join(name, "blocks") + ) + end + + defp attention_pooling_head(hidden_state, spec, opts) do + name = opts[:name] + + probe = + Layers.learned_embeddings(1, spec.hidden_size, + name: join(name, "probe"), + initializer: Axon.Initializers.normal() + ) + + attended = + multihead_attention(probe, hidden_state, spec, name: join(name, "attention")) + + attended = Axon.nx(attended, fn x -> Nx.squeeze(x, axes: [1]) end) + + normed = + Axon.layer_norm(attended, epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + + mlp_output = + normed + |> Axon.dense(spec.intermediate_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "mlp.intermediate") + ) + |> Layers.activation(spec.activation) + |> Axon.dense(spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "mlp.output") + ) + + Axon.add(attended, mlp_output) + end + + defp multihead_attention(query, key_value, spec, opts) do + name = opts[:name] + num_heads = spec.num_attention_heads + head_dim = div(spec.hidden_size, num_heads) + + q = + Axon.dense(query, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "query") + ) + + k = + Axon.dense(key_value, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "key") + ) + + v = + Axon.dense(key_value, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "value") + ) + + q = Axon.nx(q, fn x -> reshape_for_attention(x, num_heads, head_dim) end) + k = Axon.nx(k, fn x -> reshape_for_attention(x, num_heads, head_dim) end) + v = Axon.nx(v, fn x -> reshape_for_attention(x, num_heads, head_dim) end) + + scale = :math.sqrt(head_dim) + + attention_output = + Axon.layer( + fn q, k, v, _opts -> + # Broadcast q to match k's batch size (for attention pooling head) + {batch_k, _, _, _} = Nx.shape(k) + {batch_q, heads, seq_q, head_d} = Nx.shape(q) + + q = + if batch_q == 1 and batch_k > 1 do + Nx.broadcast(q, {batch_k, heads, seq_q, head_d}) + else + q + end + + scores = Nx.dot(q, [3], [0, 1], k, [3], [0, 1]) |> Nx.divide(scale) + weights = Axon.Activations.softmax(scores, axis: -1) + Nx.dot(weights, [3], [0, 1], v, [2], [0, 1]) + end, + [q, k, v], + name: join(name, "attention"), + op_name: :attention + ) + + attention_output = + Axon.nx(attention_output, fn x -> + {batch, _heads, seq_len, _head_dim} = Nx.shape(x) + + Nx.transpose(x, axes: [0, 2, 1, 3]) + |> Nx.reshape({batch, seq_len, spec.hidden_size}) + end) + + Axon.dense(attention_output, spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(), + name: join(name, "output") + ) + end + + defp reshape_for_attention(x, num_heads, head_dim) do + {batch, seq_len, _hidden} = Nx.shape(x) + x |> Nx.reshape({batch, seq_len, num_heads, head_dim}) |> Nx.transpose(axes: [0, 2, 1, 3]) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, %{"model_type" => "siglip", "vision_config" => data}) do + load(spec, data) + end + + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + image_size: {"image_size", number()}, + patch_size: {"patch_size", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + intermediate_size: {"intermediate_size", number()}, 
+ activation: {"hidden_act", activation()}, + attention_dropout_rate: {"attention_dropout", number()}, + layer_norm_epsilon: {"layer_norm_eps", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.patch_embedding" => "vision_model.embeddings.patch_embedding", + "embedder.position_embedding" => "vision_model.embeddings.position_embedding", + "encoder.blocks.{n}.self_attention_norm" => "vision_model.encoder.layers.{n}.layer_norm1", + "encoder.blocks.{n}.self_attention.query" => + "vision_model.encoder.layers.{n}.self_attn.q_proj", + "encoder.blocks.{n}.self_attention.key" => + "vision_model.encoder.layers.{n}.self_attn.k_proj", + "encoder.blocks.{n}.self_attention.value" => + "vision_model.encoder.layers.{n}.self_attn.v_proj", + "encoder.blocks.{n}.self_attention.output" => + "vision_model.encoder.layers.{n}.self_attn.out_proj", + "encoder.blocks.{n}.ffn.intermediate" => "vision_model.encoder.layers.{n}.mlp.fc1", + "encoder.blocks.{n}.ffn.output" => "vision_model.encoder.layers.{n}.mlp.fc2", + "encoder.blocks.{n}.output_norm" => "vision_model.encoder.layers.{n}.layer_norm2", + "post_norm" => "vision_model.post_layernorm", + "head.probe" => %{ + "embeddings" => { + [{"vision_model.head", "probe"}], + fn [probe] -> Nx.squeeze(probe, axes: [0]) end + } + }, + "head.attention.query" => %{ + "kernel" => { + [{"vision_model.head.attention", "in_proj_weight"}], + fn [kernel] -> + chunk_size = div(Nx.axis_size(kernel, 0), 3) + kernel = Nx.slice_along_axis(kernel, 0, chunk_size, axis: 0) + Nx.transpose(kernel) + end + }, + "bias" => { + [{"vision_model.head.attention", "in_proj_bias"}], + fn [bias] -> + chunk_size = div(Nx.axis_size(bias, 0), 3) + Nx.slice_along_axis(bias, 0, chunk_size, axis: 0) + end + } + }, + "head.attention.key" => %{ + "kernel" => { + [{"vision_model.head.attention", "in_proj_weight"}], + fn [kernel] -> + chunk_size = div(Nx.axis_size(kernel, 0), 3) + kernel = Nx.slice_along_axis(kernel, chunk_size, chunk_size, axis: 0) + Nx.transpose(kernel) + end + }, + "bias" => { + [{"vision_model.head.attention", "in_proj_bias"}], + fn [bias] -> + chunk_size = div(Nx.axis_size(bias, 0), 3) + Nx.slice_along_axis(bias, chunk_size, chunk_size, axis: 0) + end + } + }, + "head.attention.value" => %{ + "kernel" => { + [{"vision_model.head.attention", "in_proj_weight"}], + fn [kernel] -> + chunk_size = div(Nx.axis_size(kernel, 0), 3) + kernel = Nx.slice_along_axis(kernel, 2 * chunk_size, chunk_size, axis: 0) + Nx.transpose(kernel) + end + }, + "bias" => { + [{"vision_model.head.attention", "in_proj_bias"}], + fn [bias] -> + chunk_size = div(Nx.axis_size(bias, 0), 3) + Nx.slice_along_axis(bias, 2 * chunk_size, chunk_size, axis: 0) + end + } + }, + "head.attention.output" => "vision_model.head.attention.out_proj", + "head.norm" => "vision_model.head.layernorm", + "head.mlp.intermediate" => "vision_model.head.mlp.fc1", + "head.mlp.output" => "vision_model.head.mlp.fc2", + "image_classification_head.output" => "classifier" + } + end + end +end diff --git a/mix.exs b/mix.exs index 47b2f90f..ec6f666f 100644 --- a/mix.exs +++ b/mix.exs @@ -85,6 +85,7 @@ defmodule Bumblebee.MixProject do Bumblebee.Multimodal.Blip, Bumblebee.Multimodal.Clip, Bumblebee.Multimodal.LayoutLm, + Bumblebee.Multimodal.SigLip, Bumblebee.Text.Albert, Bumblebee.Text.Bart, Bumblebee.Text.Bert, @@ -104,6 +105,7 @@ defmodule Bumblebee.MixProject do Bumblebee.Text.Phi, 
Bumblebee.Text.Phi3, Bumblebee.Text.Roberta, + Bumblebee.Text.SigLipText, Bumblebee.Text.SmolLM3, Bumblebee.Text.T5, Bumblebee.Vision.BlipVision, @@ -112,6 +114,7 @@ defmodule Bumblebee.MixProject do Bumblebee.Vision.Deit, Bumblebee.Vision.DinoV2, Bumblebee.Vision.ResNet, + Bumblebee.Vision.SigLipVision, Bumblebee.Vision.Swin, Bumblebee.Vision.Vit ], diff --git a/test/bumblebee/multimodal/siglip_test.exs b/test/bumblebee/multimodal/siglip_test.exs new file mode 100644 index 00000000..0aea04db --- /dev/null +++ b/test/bumblebee/multimodal/siglip_test.exs @@ -0,0 +1,45 @@ +defmodule Bumblebee.Multimodal.SigLipTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "katuni4ka/tiny-random-SiglipModel"}) + + assert %Bumblebee.Multimodal.SigLip{architecture: :base} = spec + + # Image size is 30x30 for this tiny model + inputs = %{ + "input_ids" => + Nx.tensor([ + [10, 20, 30, 40, 50, 60, 70, 80, 1, 1], + [15, 25, 35, 45, 55, 65, 75, 85, 1, 1] + ]), + "pixel_values" => + Nx.concatenate([ + Nx.broadcast(0.25, {1, 30, 30, 3}), + Nx.broadcast(0.75, {1, 30, 30, 3}) + ]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits_per_text) == {2, 2} + assert Nx.shape(outputs.logits_per_image) == {2, 2} + + assert_all_close( + outputs.logits_per_text, + Nx.tensor([[-0.0626, -0.0771], [-0.0961, -0.1548]]), + atol: 1.0e-3 + ) + + assert_all_close( + outputs.logits_per_image, + Nx.tensor([[-0.0626, -0.0961], [-0.0771, -0.1548]]), + atol: 1.0e-3 + ) + end +end
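
Usage sketch (not part of the patch): the snippet below shows how the new modules are expected to fit together end to end, reusing the tiny checkpoint exercised in the test above. It assumes the featurizer and tokenizer mappings added in lib/bumblebee.ex resolve for this repository, and `image` is a placeholder for any image representation the featurizer accepts (for example an `{height, width, 3}` Nx tensor or an `StbImage`); exact padding and sequence-length handling depends on the checkpoint's tokenizer config.

repo = {:hf, "katuni4ka/tiny-random-SiglipModel"}

{:ok, %{model: model, params: params}} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

# Tokenize candidate texts and featurize the image, then feed both into the joint model.
text_inputs = Bumblebee.apply_tokenizer(tokenizer, ["a photo of a cat", "a photo of a dog"])
image_inputs = Bumblebee.apply_featurizer(featurizer, [image])
inputs = Map.merge(text_inputs, image_inputs)

outputs = Axon.predict(model, params, inputs)

# logits_per_image has shape {num_images, num_texts}; SigLIP applies a sigmoid
# (not a softmax) to each pairwise logit, giving independent match probabilities.
probabilities = Nx.sigmoid(outputs.logits_per_image)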