import torch
import torchvision
from torch import nn
from transformers import PreTrainedModel

# uncomment when writing file
from configuration import EffNetPlantDiseaseConfig


# create model class
class EffNetPlantDiseaseClassification(PreTrainedModel):
    """EfficientNetV2-S image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The torchvision ``efficientnet_v2_s`` architecture is instantiated without
    pretrained weights, and its classifier head is replaced by
    ``Dropout(config.dropout_rate) -> Linear(in_features, config.num_classes)``.
    """

    config_class = EffNetPlantDiseaseConfig

    def __init__(self, config):
        """Build the backbone and classification head from ``config``.

        Args:
            config: An ``EffNetPlantDiseaseConfig``; this constructor reads its
                ``dropout_rate`` and ``num_classes`` attributes.
        """
        super().__init__(config)

        # get the model architecture from torchvision (no weights requested here)
        backbone = torchvision.models.efficientnet_v2_s()

        # the replacement Linear keeps the feature width of the stock head's last layer
        head_in_features = backbone.classifier[-1].in_features

        # modify the classifier head according to the config
        backbone.classifier = nn.Sequential(
            nn.Dropout(p=config.dropout_rate, inplace=True),
            nn.Linear(in_features=head_in_features, out_features=config.num_classes),
        )

        # registered as `model` so parameter names stay `model.*`
        self.model = backbone
        self.num_classes = config.num_classes
        self.loss_fn = nn.CrossEntropyLoss()

    # define forward method to be similar to hugging face model standards
    def forward(self, image, label=None):
        """Run the classifier and optionally compute the cross-entropy loss.

        Args:
            image: Input batch passed straight to the torchvision backbone
                (presumably a float tensor of shape (batch, 3, H, W) — confirm
                against the data pipeline).
            label: Optional class-index targets for ``nn.CrossEntropyLoss``.

        Returns:
            dict: ``{"logits": ..., "loss": ...}``. ``loss`` is ``None`` when
            ``label`` is not given.
        """
        scores = self.model(image)
        # NOTE(review): a None "loss" entry may break HF Trainer.predict on
        # unlabeled data (it can end up among the detached outputs) — verify.
        loss = None if label is None else self.loss_fn(scores, label)
        return {"logits": scores, "loss": loss}