Krea2Transformer2DModel
The single-stream MMDiT flow-matching transformer used by Krea 2.
Krea2Transformer2DModel
class diffusers.Krea2Transformer2DModel(in_channels: int = 64, num_layers: int = 28, attention_head_dim: int = 128, num_attention_heads: int = 48, num_key_value_heads: int = 12, intermediate_size: int = 16384, timestep_embed_dim: int = 256, text_hidden_dim: int = 2560, num_text_layers: int = 12, text_num_attention_heads: int = 20, text_num_key_value_heads: int = 20, text_intermediate_size: int = 6912, num_layerwise_text_blocks: int = 2, num_refiner_text_blocks: int = 2, axes_dims_rope: tuple = (32, 48, 48), rope_theta: float = 1000.0, norm_eps: float = 1e-05)
Parameters
- in_channels (int, defaults to 64): Latent channel count after patchification (vae_channels * patch_size ** 2).
- num_layers (int, defaults to 28): Number of transformer blocks.
- attention_head_dim (int, defaults to 128): Dimension of each attention head; the total hidden size is attention_head_dim * num_attention_heads (6144 with the defaults).
- num_attention_heads (int, defaults to 48): Number of query heads.
- num_key_value_heads (int, defaults to 12): Number of key/value heads for grouped-query attention.
- intermediate_size (int, defaults to 16384): Feed-forward hidden size of the SwiGLU MLP inside each block.
- timestep_embed_dim (int, defaults to 256): Width of the sinusoidal timestep embedding before its MLP.
- text_hidden_dim (int, defaults to 2560): Hidden size of the text encoder whose hidden states are consumed.
- num_text_layers (int, defaults to 12): Number of tapped text-encoder hidden states stacked per token.
- text_num_attention_heads (int, defaults to 20): Number of query heads in the text fusion blocks.
- text_num_key_value_heads (int, defaults to 20): Number of key/value heads in the text fusion blocks.
- text_intermediate_size (int, defaults to 6912): Feed-forward hidden size of the SwiGLU MLP inside the text fusion blocks.
- num_layerwise_text_blocks (int, defaults to 2): Number of text fusion blocks applied across the tapped-layer axis (per token).
- num_refiner_text_blocks (int, defaults to 2): Number of text fusion blocks applied across the token sequence.
- axes_dims_rope (tuple[int, int, int], defaults to (32, 48, 48)): Head-dim split across the (t, h, w) rotary position axes; the three entries sum to attention_head_dim.
- rope_theta (float, defaults to 1000.0): Base used by the rotary position embedding.
- norm_eps (float, defaults to 1e-5): Epsilon used by all RMSNorm modules.
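The defaults above imply several derived sizes. The sketch below recomputes them in plain Python with no diffusers dependency. `Krea2Dims` and `derived_sizes` are hypothetical helpers, and splitting `in_channels = 64` into `vae_channels = 16` and `patch_size = 2` is an assumption made for illustration (other factorizations, such as 64 channels with a patch size of 1, also give 64).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Krea2Dims:
    """Subset of the documented defaults (hypothetical container, not a diffusers class)."""
    in_channels: int = 64
    attention_head_dim: int = 128
    num_attention_heads: int = 48
    num_key_value_heads: int = 12
    axes_dims_rope: tuple = (32, 48, 48)


def derived_sizes(cfg: Krea2Dims, vae_channels: int = 16, patch_size: int = 2) -> dict:
    """Quantities implied by the config; vae_channels and patch_size are assumptions."""
    # Total hidden size of the transformer blocks.
    hidden = cfg.attention_head_dim * cfg.num_attention_heads
    # Grouped-query attention: each key/value head serves this many query heads.
    if cfg.num_attention_heads % cfg.num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    group = cfg.num_attention_heads // cfg.num_key_value_heads
    # Output width of the K and V projections under GQA.
    kv_width = cfg.num_key_value_heads * cfg.attention_head_dim
    # Patchification packs patch_size x patch_size latent pixels into one token.
    packed = vae_channels * patch_size ** 2
    return {
        "hidden_size": hidden,
        "queries_per_kv_head": group,
        "kv_width": kv_width,
        # The (t, h, w) rotary slices must together cover one attention head.
        "rope_dims_match_head": sum(cfg.axes_dims_rope) == cfg.attention_head_dim,
        "packed_channels_match": packed == cfg.in_channels,
    }


print(derived_sizes(Krea2Dims()))
# {'hidden_size': 6144, 'queries_per_kv_head': 4, 'kv_width': 1536, 'rope_dims_match_head': True, 'packed_channels_match': True}
```

With the defaults, each of the 12 key/value heads is shared by 4 query heads, so the K and V projections are a quarter of the width of the Q projection.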
The single-stream MMDiT flow-matching backbone used by the Krea 2 pipeline.
Text conditioning enters as a stack of hidden states tapped from several layers of a multimodal text encoder. A
small text-fusion transformer collapses the layer axis and refines the token sequence; the result is concatenated
with the patchified image latents into a single [text, image] sequence processed by the transformer blocks. The
timestep conditions every block through one shared modulation vector plus per-block learned tables.
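Two pieces of bookkeeping follow from this description: how many image tokens an image produces, and how text padding carries over to the joint [text, image] sequence. The sketch also shows one plausible reading of "one shared modulation vector plus per-block learned tables", modeled on the adaLN-single scheme from PixArt-α, in which each block adds its own learned table to a shared timestep projection and splits the sum into modulation chunks. The helper names, the 8x VAE downsampling factor, the 2x2 patch size, and the six-way chunking are all assumptions, not the Krea 2 implementation.

```python
def image_seq_len(height: int, width: int, vae_scale: int = 8, patch_size: int = 2) -> int:
    """Image tokens for a height x width image (vae_scale and patch_size are assumptions)."""
    step = vae_scale * patch_size
    return (height // step) * (width // step)


def joint_key_mask(text_mask: list, num_image_tokens: int) -> list:
    """Extend a per-token text mask to the joint [text, image] sequence."""
    # Padded text tokens stay masked; every image token is valid.
    return list(text_mask) + [True] * num_image_tokens


def block_modulation(shared: list, tables: list, block_index: int, num_chunks: int = 6) -> list:
    """adaLN-single-style modulation for one block (hypothetical layout).

    shared: timestep-derived vector reused by every block.
    tables: one learned vector per block, each the same length as shared.
    """
    combined = [s + t for s, t in zip(shared, tables[block_index])]
    size = len(combined) // num_chunks
    # With six chunks this would be shift/scale/gate for attention and for the MLP.
    return [combined[i * size:(i + 1) * size] for i in range(num_chunks)]


n_img = image_seq_len(1024, 1024)
print(n_img)  # 4096
print(len(joint_key_mask([True] * 77 + [False] * 51, n_img)))  # 4224
```

Sharing one timestep projection across all blocks keeps the per-block cost of timestep conditioning to a single learned table instead of a separate projection in every block, which is the motivation PixArt-α gives for adaLN-single.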
forward(hidden_states: Tensor, encoder_hidden_states: Tensor, timestep: Tensor, position_ids: Tensor, encoder_attention_mask: torch.Tensor | None = None, return_dict: bool = True)
Parameters
- hidden_states (torch.Tensor of shape (batch_size, image_seq_len, in_channels)): Packed (patchified) noisy image latents.
- encoder_hidden_states (torch.Tensor of shape (batch_size, text_seq_len, num_text_layers, text_hidden_dim)): Stack of tapped text-encoder hidden states per token.
- timestep (torch.Tensor of shape (batch_size,)): Flow-matching time in [0, 1] (1 is pure noise, 0 is clean data).
- position_ids (torch.Tensor of shape (text_seq_len + image_seq_len, 3)): (t, h, w) rotary coordinates for the combined sequence. Text rows are all-zero; image rows hold the latent-grid coordinates.
- encoder_attention_mask (torch.Tensor of shape (batch_size, text_seq_len), optional): Boolean mask marking valid text tokens. Pass None when every text token is valid.
- return_dict (bool, optional, defaults to True): Whether to return a Transformer2DModelOutput instead of a plain tuple.
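The position_ids layout can be built, and turned into rotary angles, roughly as follows. This is a sketch under assumptions: image rows are assumed to use t = 0 and row-major (h, w) order matching the packing order of hidden_states, and the per-axis frequencies use the standard RoPE formula theta ** (-2i / dim). `build_position_ids` and `rope_angles` are hypothetical helpers, not diffusers functions.

```python
def build_position_ids(text_seq_len: int, grid_h: int, grid_w: int, t_index: int = 0) -> list:
    """(t, h, w) rows for the joint [text, image] sequence.

    Text rows are all zero; image rows hold latent-grid coordinates. The image
    t value and the row-major order are assumptions.
    """
    rows = [(0, 0, 0)] * text_seq_len
    rows += [(t_index, h, w) for h in range(grid_h) for w in range(grid_w)]
    return rows


def rope_angles(position: tuple, axes_dims: tuple = (32, 48, 48), theta: float = 1000.0) -> list:
    """Rotation angles for one token; each axis owns its own slice of the head dim."""
    angles = []
    for pos, dim in zip(position, axes_dims):
        # Standard RoPE frequencies: theta ** (-2i / dim) for i in 0 .. dim/2 - 1.
        angles.extend(pos * theta ** (-2 * i / dim) for i in range(dim // 2))
    return angles  # one angle per rotated pair: sum(axes_dims) // 2 in total


ids = build_position_ids(text_seq_len=3, grid_h=2, grid_w=2)
print(ids)
# [(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)]
print(len(rope_angles(ids[-1])))  # 64
```

Wrapping `ids` in `torch.tensor` would give the documented (text_seq_len + image_seq_len, 3) shape. Because every text row is (0, 0, 0), all text tokens receive zero rotation angles, so RoPE contributes no positional signal between text tokens; only image tokens are distinguished by position.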
Predict the flow-matching velocity for the image tokens.
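To show how a velocity prediction is typically consumed, here is a minimal Euler sampler in plain Python. It assumes the linear path x_t = (1 - t) * x_0 + t * noise, which matches the timestep convention above (1 is pure noise, 0 is clean data) and has velocity noise - x_0. The update x + (t_next - t) * v has the same form as the step in diffusers' FlowMatchEulerDiscreteScheduler; whether the Krea 2 pipeline uses that scheduler is not stated here. `euler_sample` and the oracle velocity are illustrative stand-ins for the pipeline and the transformer.

```python
def euler_sample(velocity_fn, noise: list, timesteps: list) -> list:
    """Integrate dx/dt = v(x, t) from t=1 (noise) down to t=0 (data) with Euler steps."""
    x = list(noise)
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        v = velocity_fn(x, t)
        # t decreases, so (t_next - t) is negative: each step moves toward the data.
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]
    return x


data = [0.5, -1.0, 2.0]
noise = [1.0, 0.0, -0.5]


def oracle_velocity(x: list, t: float) -> list:
    # Exact velocity for a single known data point on the linear path:
    # x_t - x_0 = t * (noise - x_0), so v = (x_t - x_0) / t. Only called with t > 0.
    return [(xi - di) / t for xi, di in zip(x, data)]


timesteps = [1 - i / 10 for i in range(11)]  # 1.0, 0.9, ..., 0.0
print(euler_sample(oracle_velocity, noise, timesteps))  # approximately [0.5, -1.0, 2.0]
```

Because the assumed path is a straight line, a perfect velocity stays constant along it and Euler integration recovers the data point up to floating-point error; a learned model's velocity is only approximate, so real samplers trade step count against accuracy.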