LutherXD committed on
Commit
1f332cc
·
verified ·
1 Parent(s): a51c64c

add _supports_sdpa property to support more transformers versions

Browse files
Files changed (1) hide show
  1. modeling_opencua.py +5 -4
modeling_opencua.py CHANGED
@@ -67,6 +67,10 @@ class OpenCUAPreTrainedModel(PreTrainedModel):
67
  _skip_keys_device_placement = "past_key_values"
68
  _supports_flash_attn_2 = True
69
 
 
 
 
 
70
  def _init_weights(self, module):
71
  # important: this ported version of Llava isn't meant for training from scratch - only
72
  # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
@@ -95,11 +99,8 @@ class OpenCUAForConditionalGeneration(OpenCUAPreTrainedModel):
95
  super().__init__(config)
96
  self.vision_tower = Qwen2_5_VisionTransformerPretrainedModel(config.vision_config)
97
  self.language_model = Qwen2ForCausalLM(config.text_config)
 
98
  self.post_init()
99
-
100
- @property
101
- def _supports_sdpa(self):
102
- return self.language_model._supports_sdpa
103
 
104
  # 使用 property 来创建动态属性
105
  @property
 
67
  _skip_keys_device_placement = "past_key_values"
68
  _supports_flash_attn_2 = True
69
 
70
+ supports_gradient_checkpointing = True
71
+
72
+ _supports_sdpa = True
73
+
74
  def _init_weights(self, module):
75
  # important: this ported version of Llava isn't meant for training from scratch - only
76
  # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
 
99
  super().__init__(config)
100
  self.vision_tower = Qwen2_5_VisionTransformerPretrainedModel(config.vision_config)
101
  self.language_model = Qwen2ForCausalLM(config.text_config)
102
+ self._supports_sdpa = True
103
  self.post_init()
 
 
 
 
104
 
105
  # 使用 property 来创建动态属性
106
  @property