# coding=utf-8
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model classes for MetricX, modified from the T5 versions in HF."""

import copy
import dataclasses
from typing import Optional, Tuple, Union
import warnings

import torch
from torch import nn
import transformers
import transformers.modeling_outputs

BaseModelOutput = transformers.modeling_outputs.BaseModelOutput
ModelOutput = transformers.modeling_outputs.ModelOutput
MT5Config = transformers.models.mt5.modeling_mt5.MT5Config
MT5PreTrainedModel = transformers.models.mt5.modeling_mt5.MT5PreTrainedModel
MT5Stack = transformers.models.mt5.modeling_mt5.MT5Stack
__HEAD_MASK_WARNING_MSG = (
    transformers.models.mt5.modeling_mt5.__HEAD_MASK_WARNING_MSG  # pylint: disable=protected-access
)


@dataclasses.dataclass
class MT5ForRegressionOutput(ModelOutput):
  loss: Optional[torch.FloatTensor] = None
  predictions: torch.FloatTensor = None


class MT5ForRegression(MT5PreTrainedModel):
  """MT5 model for regression."""

  def __init__(self, config: MT5Config):
    super().__init__(config)
    self.model_dim = config.d_model

    self.shared = nn.Embedding(config.vocab_size, config.d_model)

    encoder_config = copy.deepcopy(config)
    encoder_config.is_decoder = False
    encoder_config.use_cache = False
    encoder_config.is_encoder_decoder = False
    self.encoder = MT5Stack(encoder_config, self.shared)

    decoder_config = copy.deepcopy(config)
    decoder_config.is_decoder = True
    decoder_config.is_encoder_decoder = False
    decoder_config.num_layers = config.num_decoder_layers
    self.decoder = MT5Stack(decoder_config, self.shared)

    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()

    # Model parallel
    self.model_parallel = False
    self.device_map = None

  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.FloatTensor] = None,
      decoder_attention_mask: Optional[torch.BoolTensor] = None,
      head_mask: Optional[torch.FloatTensor] = None,
      decoder_head_mask: Optional[torch.FloatTensor] = None,
      cross_attn_head_mask: Optional[torch.Tensor] = None,
      encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
      past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.FloatTensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
  ) -> Union[Tuple[torch.FloatTensor], MT5ForRegressionOutput]:
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # FutureWarning: head_mask was separated into two input args - head_mask,
    # decoder_head_mask
    if head_mask is not None and decoder_head_mask is None:
      if self.config.num_layers == self.config.num_decoder_layers:
        warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
        decoder_head_mask = head_mask

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
      # Convert encoder inputs in embeddings if needed
      encoder_outputs = self.encoder(
          input_ids=input_ids,
          attention_mask=attention_mask,
          inputs_embeds=inputs_embeds,
          head_mask=head_mask,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
      encoder_outputs = BaseModelOutput(
          last_hidden_state=encoder_outputs[0],
          hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
          attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
      )

    hidden_states = encoder_outputs[0]

    if self.model_parallel:
      torch.cuda.set_device(self.decoder.first_device)

    # Create 1 step of dummy input for the decoder.
    batch_size = input_ids.size(0)
    decoder_input_ids = torch.LongTensor([0]).repeat(batch_size).reshape(-1, 1)
    if torch.cuda.is_available():
      decoder_input_ids = decoder_input_ids.to(torch.device("cuda"))

    # Set device for model parallelism
    if self.model_parallel:
      torch.cuda.set_device(self.decoder.first_device)
      hidden_states = hidden_states.to(self.decoder.first_device)
      if decoder_input_ids is not None:
        decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
      if attention_mask is not None:
        attention_mask = attention_mask.to(self.decoder.first_device)
      if decoder_attention_mask is not None:
        decoder_attention_mask = decoder_attention_mask.to(
            self.decoder.first_device
        )

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
      torch.cuda.set_device(self.encoder.first_device)
      self.lm_head = self.lm_head.to(self.encoder.first_device)
      sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
      # Rescale output before projecting on vocab
      # See
      # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
      sequence_output = sequence_output * (self.model_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)

    # 250089 = <extra_id_10>
    predictions = lm_logits[:, 0, 250089]

    # Clip to 0 to 25
    predictions = torch.clamp(predictions, 0, 25)

    loss = None
    if labels is not None:
      loss_fct = nn.MSELoss()
      # move labels to correct device to enable PP
      labels = labels.to(predictions.device)
      loss = loss_fct(predictions.view(-1), labels.view(-1))

    return MT5ForRegressionOutput(
        loss=loss,
        predictions=predictions,
    )
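

if __name__ == "__main__":
  # Minimal usage sketch, not part of the original module. Assumptions beyond
  # the code above: a MetricX regression checkpoint such as
  # "google/metricx-23-xl-v2p0" and the matching "google/mt5-xl" tokenizer are
  # available, and the input follows the MetricX-23 convention of concatenating
  # the candidate and reference into one string. Predictions are clipped to
  # [0, 25] in forward(); lower scores indicate better translations.
  tokenizer = transformers.AutoTokenizer.from_pretrained("google/mt5-xl")
  model = MT5ForRegression.from_pretrained("google/metricx-23-xl-v2p0")
  model.eval()

  # forward() moves its one-step dummy decoder input to CUDA whenever CUDA is
  # available, so keep the model and encoder inputs on the same device.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model = model.to(device)

  text = "candidate: Das ist ein Test. reference: Dies ist ein Test."
  encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
  # The MetricX prediction script drops the trailing EOS token before scoring.
  input_ids = encoded.input_ids[:, :-1].to(device)
  attention_mask = encoded.attention_mask[:, :-1].to(device)

  with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
  print(float(out.predictions[0]))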