import warnings

from transformers import pipeline

from vista3d_config import VISTA3DConfig
from vista3d_model import VISTA3DModel, register_my_model
from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline
class HuggingFacePipelineHelper:
    """Build VISTA3D inference pipelines on top of Hugging Face ``transformers``.

    Two entry points are provided:

    * :meth:`get_pipeline` runs the project's registration hooks and then
      asks ``transformers.pipeline`` for the task named ``pipeline_name``.
    * :meth:`init_pipeline` loads a ``VISTA3DModel`` checkpoint directly
      (optionally overriding config attributes) and wraps it in a
      ``VISTA3DPipeline``.
    """

    def __init__(self, pipeline_name: str = "vista3d"):
        """Store the task name later passed to ``transformers.pipeline``.

        Args:
            pipeline_name: Task name used by :meth:`get_pipeline`.
        """
        self.pipeline_name = pipeline_name

    def __model_register(self):
        """Run the project's model registration hook."""
        register_my_model()

    def __pipeline_register(self):
        """Run the project's pipeline registration hook."""
        register_simple_pipeline()

    def get_pipeline(self):
        """Register the VISTA3D model and pipeline, then build the pipeline.

        Registration runs before ``transformers.pipeline`` is called,
        presumably so the custom task name can be resolved (the hooks live
        in other modules -- confirm exactly what they register).

        Returns:
            Whatever ``transformers.pipeline(self.pipeline_name)`` returns.
        """
        self.__model_register()
        self.__pipeline_register()
        return pipeline(self.pipeline_name)

    def _update_config(self, config, config_dict):
        """Apply attribute overrides from ``config_dict`` onto ``config``.

        Only keys that already exist as attributes of ``config`` are applied;
        unknown keys are skipped with a ``UserWarning`` so that typos do not
        silently disappear.

        Args:
            config: Configuration object, mutated in place.
            config_dict: Mapping of attribute name to new value, or a falsy
                value (``None``/empty) to leave ``config`` untouched.

        Returns:
            The same ``config`` object.
        """
        if not config_dict:
            return config
        for key, value in config_dict.items():
            if hasattr(config, key):
                # Assign unconditionally: re-setting an equal value is a no-op,
                # and a `!=` guard breaks on array-like values whose comparison
                # does not yield a plain bool.
                setattr(config, key, value)
            else:
                warnings.warn(
                    f"Ignoring unknown config key {key!r}: "
                    f"{type(config).__name__} has no such attribute.",
                    stacklevel=2,
                )
        return config

    def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
        """Load a VISTA3D checkpoint and wrap it in a ``VISTA3DPipeline``.

        Args:
            pretrained_model_name_or_path: Checkpoint location, in any form
                accepted by transformers' ``from_pretrained``.
            **kwargs: ``config_dict`` (optional mapping of config attribute
                overrides) is consumed here; every remaining keyword argument
                is forwarded to ``VISTA3DPipeline``.

        Returns:
            A ``VISTA3DPipeline`` wrapping the loaded model.
        """
        config_dict = kwargs.pop("config_dict", None)
        # Start from the checkpoint's own config so architecture parameters
        # match the stored weights, then layer the caller's overrides on top.
        config = VISTA3DConfig.from_pretrained(pretrained_model_name_or_path)
        self._update_config(config, config_dict)
        # `from_pretrained` is a classmethod; calling it through a throwaway
        # instance discarded that instance's config, so overrides never took
        # effect. Hand the updated config to the loader explicitly instead.
        model = VISTA3DModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
        )
        return VISTA3DPipeline(model, **kwargs)