Diffusers documentation

Krea2Transformer2DModel

You are viewing the main version, which requires installation from source. For a regular pip install, check out the latest stable version (v0.38.0).

The single-stream MMDiT flow-matching transformer used by Krea 2.

Krea2Transformer2DModel

class diffusers.Krea2Transformer2DModel

( in_channels: int = 64, num_layers: int = 28, attention_head_dim: int = 128, num_attention_heads: int = 48, num_key_value_heads: int = 12, intermediate_size: int = 16384, timestep_embed_dim: int = 256, text_hidden_dim: int = 2560, num_text_layers: int = 12, text_num_attention_heads: int = 20, text_num_key_value_heads: int = 20, text_intermediate_size: int = 6912, num_layerwise_text_blocks: int = 2, num_refiner_text_blocks: int = 2, axes_dims_rope: tuple = (32, 48, 48), rope_theta: float = 1000.0, norm_eps: float = 1e-05 )

Parameters

  • in_channels (int, defaults to 64) — Latent channel count after patchification (vae_channels * patch_size ** 2).
  • num_layers (int, defaults to 28) — Number of transformer blocks.
  • attention_head_dim (int, defaults to 128) — Dimension of each attention head; the total hidden size is attention_head_dim * num_attention_heads.
  • num_attention_heads (int, defaults to 48) — Number of query heads.
  • num_key_value_heads (int, defaults to 12) — Number of key/value heads for grouped-query attention.
  • intermediate_size (int, defaults to 16384) — Feed-forward hidden size of the SwiGLU MLP inside each block.
  • timestep_embed_dim (int, defaults to 256) — Width of the sinusoidal timestep embedding before its MLP.
  • text_hidden_dim (int, defaults to 2560) — Hidden size of the text encoder whose hidden states are consumed.
  • num_text_layers (int, defaults to 12) — Number of tapped text-encoder hidden states stacked per token.
  • text_num_attention_heads (int, defaults to 20) — Number of query heads in the text fusion blocks.
  • text_num_key_value_heads (int, defaults to 20) — Number of key/value heads in the text fusion blocks.
  • text_intermediate_size (int, defaults to 6912) — Feed-forward hidden size of the SwiGLU MLP inside the text fusion blocks.
  • num_layerwise_text_blocks (int, defaults to 2) — Number of text fusion blocks applied across the tapped-layer axis (per token).
  • num_refiner_text_blocks (int, defaults to 2) — Number of text fusion blocks applied across the token sequence.
  • axes_dims_rope (tuple[int, int, int], defaults to (32, 48, 48)) — Head-dim split across the (t, h, w) rotary position axes.
  • rope_theta (float, defaults to 1000.0) — Base used by the rotary position embedding.
  • norm_eps (float, defaults to 1e-5) — Epsilon used by all RMSNorm modules.
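
Several of these defaults are tied together by simple arithmetic. The following is a minimal sketch that derives the quantities implied by the parameter descriptions above. The variable names `hidden_size`, `queries_per_kv_head`, `rope_dims_total`, `vae_channels`, and `patch_size` are illustrative, and the VAE values are hypothetical: the page only states the product `vae_channels * patch_size ** 2`.

```python
# Default constructor values, copied from the signature above.
attention_head_dim = 128
num_attention_heads = 48
num_key_value_heads = 12
axes_dims_rope = (32, 48, 48)
in_channels = 64

# Total hidden size, per the attention_head_dim description.
hidden_size = attention_head_dim * num_attention_heads

# Grouped-query attention: number of query heads sharing one key/value head.
queries_per_kv_head = num_attention_heads // num_key_value_heads

# With the defaults, the (t, h, w) rotary split adds up to one full head.
rope_dims_total = sum(axes_dims_rope)

# Hypothetical VAE settings (assumption, not stated on this page):
# 16 latent channels with 2x2 patches would reproduce the default in_channels.
vae_channels, patch_size = 16, 2
patchified_channels = vae_channels * patch_size ** 2

print(hidden_size, queries_per_kv_head, rope_dims_total, patchified_channels)
# prints: 6144 4 128 64
```

Whether the model enforces `sum(axes_dims_rope) == attention_head_dim` is not stated here; the defaults simply happen to satisfy it.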

The single-stream MMDiT flow-matching backbone used by the Krea 2 pipeline.

Text conditioning enters as a stack of hidden states tapped from several layers of a multimodal text encoder. A small text-fusion transformer collapses the layer axis and refines the token sequence; the result is concatenated with the patchified image latents into a single [text, image] sequence processed by the transformer blocks. The timestep conditions every block through one shared modulation vector plus per-block learned tables.
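
The data flow described above can be followed as a sequence of tensor shapes. This is a shape-only sketch, not the library's implementation: `trace_shapes` is a hypothetical helper, and the two projection steps (text and image tokens mapped to a common hidden size before concatenation) are assumptions inferred from the fact that both parts end up in one sequence.

```python
def trace_shapes(batch_size, text_seq_len, image_seq_len,
                 num_text_layers=12, text_hidden_dim=2560,
                 hidden_size=128 * 48, in_channels=64):
    """Return (stage, shape) pairs for the flow described above (shapes only)."""
    return [
        # Input: hidden states tapped from several text-encoder layers, per token.
        ("text stack", (batch_size, text_seq_len, num_text_layers, text_hidden_dim)),
        # After text fusion the layer axis is gone; one vector per token remains.
        # Assumption: fused tokens are projected to the transformer hidden size.
        ("fused text", (batch_size, text_seq_len, hidden_size)),
        # Patchified latents as passed to forward().
        ("image latents", (batch_size, image_seq_len, in_channels)),
        # Assumption: latents are embedded to the hidden size before concatenation.
        ("image tokens", (batch_size, image_seq_len, hidden_size)),
        # Single [text, image] sequence processed by the transformer blocks.
        ("joint sequence", (batch_size, text_seq_len + image_seq_len, hidden_size)),
    ]


for stage, shape in trace_shapes(batch_size=1, text_seq_len=77, image_seq_len=4096):
    print(f"{stage:>14}: {shape}")
```

The sequence lengths in the usage line are arbitrary example values.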

forward

( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.Tensor, position_ids: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, return_dict: bool = True )

Parameters

  • hidden_states (torch.Tensor of shape (batch_size, image_seq_len, in_channels)) — Packed (patchified) noisy image latents.
  • encoder_hidden_states (torch.Tensor of shape (batch_size, text_seq_len, num_text_layers, text_hidden_dim)) — Stack of tapped text-encoder hidden states per token.
  • timestep (torch.Tensor of shape (batch_size,)) — Flow-matching time in [0, 1] (1 is pure noise, 0 is clean data).
  • position_ids (torch.Tensor of shape (text_seq_len + image_seq_len, 3)) — (t, h, w) rotary coordinates for the combined sequence. Text rows are all-zero; image rows hold the latent-grid coordinates.
  • encoder_attention_mask (torch.Tensor of shape (batch_size, text_seq_len), optional) — Boolean mask marking valid text tokens. Pass None when every text token is valid.
  • return_dict (bool, optional, defaults to True) — Whether to return a Transformer2DModelOutput instead of a plain tuple.
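
The `position_ids` layout can be illustrated without any tensor library. The sketch below builds the rows as plain tuples. `build_position_ids` is a hypothetical helper, and the row-major token order and the fixed `t = 0` for image rows are assumptions; the page only states that text rows are all-zero and image rows hold latent-grid coordinates.

```python
def build_position_ids(text_seq_len, grid_height, grid_width):
    """Build text_seq_len + grid_height * grid_width rows of (t, h, w)."""
    # Text rows are all-zero, as stated in the position_ids description.
    rows = [(0, 0, 0) for _ in range(text_seq_len)]
    # Assumption: image tokens are laid out row-major with t fixed at 0.
    rows += [(0, h, w) for h in range(grid_height) for w in range(grid_width)]
    return rows


position_ids = build_position_ids(text_seq_len=3, grid_height=2, grid_width=2)
print(len(position_ids))  # 7
print(position_ids[3:])   # [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
```

The nested list has the documented shape (text_seq_len + image_seq_len, 3) and could be converted with `torch.tensor(position_ids)` before being passed to the model.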

Predict the flow-matching velocity for the image tokens.
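
To make the time convention concrete (1 is pure noise, 0 is clean data), here is a toy Euler integration on scalars. It assumes the common linear interpolation x_t = (1 - t) * x_data + t * x_noise with velocity x_noise - x_data; this page does not state the interpolation or the sign convention, so treat both as assumptions. `euler_sample` and `exact_velocity` are hypothetical names, and the exact velocity stands in for the model's prediction.

```python
def euler_sample(velocity_fn, x_noise, num_steps=4):
    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 (data)."""
    x = x_noise
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt
        # Moving towards t = 0 means stepping against the velocity.
        x = x - dt * velocity_fn(x, t)
    return x


x_data, x_noise = 2.0, -1.0


def exact_velocity(x, t):
    # Assumed convention: d/dt of (1 - t) * x_data + t * x_noise.
    return x_noise - x_data


x = euler_sample(exact_velocity, x_noise)
print(x)  # 2.0
```

With a constant velocity the Euler steps land exactly on `x_data`; a learned velocity would only approximate this.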
