---
license: cc-by-sa-4.0
datasets:
- Smith42/galaxies
- Smith42/galaxies_metadata
- Smith42/galaxies_embeddings
tags:
- astronomy
- images
- huggingscience
- science
---
<center>
<img src="assets/shoggoth_telescope_sticker_2.png" alt="astroPT_shoggoth" width="300px"/>
</center>

# astroPTv2.0: a Large Observation Model for Astronomy

This repository contains the model files for the astroPT project. The code to run inference
with these models is available at:
[https://github.com/smith42/astropt](https://github.com/smith42/astropt)

The fully trained models (pretrained on 8.6 million galaxies) are in the `astropt`
directory, in folders labelled with each model's parameter count (e.g. `astropt/095M`).

Unlike the older models, which were trained on the "image" column in
[smith42/galaxies](https://huggingface.co/datasets/Smith42/galaxies),
these models are trained on the cropped galaxies in the "image_crop"
column. These galaxies were cropped and zoomed before upload so that each one
fills the majority of its image.

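
To make "cropped and zoomed" concrete, here is a minimal NumPy sketch that cuts a square window out of the image centre and resamples it back to the original size with nearest-neighbour indexing. The helper name `centre_crop_zoom`, the fixed centre, and the `crop_frac` parameter are hypothetical choices for illustration; this is not the pipeline that produced the "image_crop" column.

```python
import numpy as np

def centre_crop_zoom(image: np.ndarray, crop_frac: float) -> np.ndarray:
    """Hypothetical helper: crop a centred square of side crop_frac * min(H, W)
    and upsample it back to (H, W) with nearest-neighbour sampling.
    Works for (H, W) and (H, W, C) arrays."""
    if not 0.0 < crop_frac <= 1.0:
        raise ValueError("crop_frac must be in (0, 1]")
    h, w = image.shape[:2]
    side = max(1, int(round(crop_frac * min(h, w))))
    top = (h - side) // 2
    left = (w - side) // 2
    crop = image[top:top + side, left:left + side]
    # Map every output pixel back to its nearest source pixel in the crop.
    rows = np.arange(h) * side // h
    cols = np.arange(w) * side // w
    return crop[rows[:, None], cols[None, :]]

# A toy 8x8 "galaxy": a bright 2x2 core in the middle of a dark frame.
img = np.zeros((8, 8))
img[3:5, 3:5] = 1.0
zoomed = centre_crop_zoom(img, crop_frac=0.5)
print(float(img.mean()), float(zoomed.mean()))  # 0.0625 0.25
```

In the toy example the bright core grows from 1/16 of the pixels to 1/4; a smaller `crop_frac` would make it fill more of the frame.
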
We see promising scaling on this new dataset:

<center>
<img src="assets/scaling_law.png" alt="scaling_law" width="300px"/>
</center>

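
Scaling results are commonly summarised with a power law, `L(N) = a * N**(-alpha)`, where `L` is a loss and `N` a size such as the parameter count. Assuming the points are (size, loss) pairs, such a law can be fitted as a straight line in log-log space, as in the sketch below. The sizes match the released 015M, 095M and 850M checkpoints, but the losses are synthetic values generated from a made-up law; they are not astroPT results.

```python
import numpy as np

def fit_power_law(n, loss) -> tuple[float, float]:
    """Fit loss ~= a * n**(-alpha) by least squares on log(loss) vs log(n).

    Returns (a, alpha)."""
    n = np.asarray(n, dtype=float)
    loss = np.asarray(loss, dtype=float)
    # In log space the power law is a line: log L = log a - alpha * log N.
    slope, intercept = np.polyfit(np.log(n), np.log(loss), deg=1)
    return float(np.exp(intercept)), float(-slope)

# Synthetic check only: losses generated from a = 5.0, alpha = 0.1.
sizes = np.array([15e6, 95e6, 850e6])
losses = 5.0 * sizes ** -0.1
a, alpha = fit_power_law(sizes, losses)
print(round(a, 3), round(alpha, 3))  # 5.0 0.1
```
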
## Usage

To use these models in anger, `pip install astropt` and run the following code:

```python
from astropt.model_utils import load_astropt
from astropt.local_datasets import GalaxyImageDataset

from datasets import load_dataset  # for Smith42/galaxies

import torch
import numpy as np
from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms

# boilerplate to preprocess galaxy images
def normalise(x):
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)

def data_transforms():
    return transforms.Compose([transforms.Lambda(normalise)])

def _process_galaxy_wrapper(idx, func):
    """This function ensures that the image is tokenised in the same way as the pre-trained model is expecting"""
    # move the channel axis to the front, then tokenise with func
    galaxy = func(
        torch.from_numpy(np.array(idx["image"]).swapaxes(0, 2)).to(float)
    ).to(torch.float)
    # positions 0 .. len(galaxy) - 1 for the token sequence
    galaxy_positions = torch.arange(0, len(galaxy), dtype=torch.long)
    return {
        "images": galaxy,
        "images_positions": galaxy_positions,
    }

# load the 095M-parameter model; 015M and 850M models are also available
model = load_astropt("Smith42/astroPT_v2.0", path="astropt/095M")

galproc = GalaxyImageDataset(
    None,
    spiral=True,
    transform={"images": data_transforms()},
    modality_registry=model.modality_registry,
)

ds = (
    load_dataset("Smith42/galaxies", split="test", revision="v2.0", streaming=True)
    .select_columns("image")
    .map(partial(_process_galaxy_wrapper, func=galproc.process_galaxy))
    .with_format("torch")
)

dl = iter(DataLoader(ds, batch_size=128, num_workers=32))

zs = []
with torch.no_grad():  # inference only, so skip building the autograd graph
    for B in dl:
        zs.append(model.generate_embeddings(B)["images"].detach().cpu().numpy())
zs = np.concatenate(zs)

# do cool stuff with zs...
```
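
The `normalise` transform in the usage code standardises along `dim=1`. Assuming `process_galaxy` hands the transform a 2-D tensor of shape `(num_patches, patch_dim)` (which is what using `len(galaxy)` as the number of positions suggests), each patch token ends up with zero mean and unit standard deviation. The NumPy sketch below mirrors that step on a toy image; the 16-pixel patch size, the row-major patch order, and the helper names are assumptions, not astroPT's own settings.

```python
import numpy as np

def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patch x patch tiles,
    flattened row by row into a (num_patches, patch * patch * C) array."""
    h, w, c = image.shape
    if h % patch or w % patch:
        raise ValueError("image sides must be divisible by the patch size")
    tiles = image.reshape(h // patch, patch, w // patch, patch, c)
    tiles = tiles.transpose(0, 2, 1, 3, 4)  # (rows, cols, patch, patch, C)
    return tiles.reshape(-1, patch * patch * c)

def normalise_tokens(x: np.ndarray) -> np.ndarray:
    """NumPy analogue of the card's normalise: standardise each row (token).
    ddof=1 matches the unbiased default of torch.std_mean."""
    mean = x.mean(axis=1, keepdims=True)
    std = x.std(axis=1, ddof=1, keepdims=True)
    return (x - mean) / (std + 1e-8)

rng = np.random.default_rng(0)
image = rng.normal(loc=3.0, scale=2.0, size=(64, 64, 3))
tokens = normalise_tokens(patchify(image, patch=16))
positions = np.arange(len(tokens))  # one integer position per patch token
print(tokens.shape, positions[:4])  # (16, 768) [0 1 2 3]
```
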
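
`GalaxyImageDataset` in the usage code is constructed with `spiral=True`. We assume this reorders the patch tokens so that the sequence starts at the central patch and winds outward; the exact ordering astroPT uses may differ. The hypothetical `spiral_order` helper below produces one such centre-out order for an `n x n` grid of patches, using only the standard library.

```python
def spiral_order(n: int) -> list[int]:
    """Row-major indices of an n x n patch grid, ordered from the centre outward.

    Built by walking the classic clockwise spiral from the top-left corner
    inward, then reversing it so the sequence starts at the innermost patch."""
    top, bottom, left, right = 0, n - 1, 0, n - 1
    inward = []
    while top <= bottom and left <= right:
        for c in range(left, right + 1):  # top edge, left to right
            inward.append(top * n + c)
        for r in range(top + 1, bottom + 1):  # right edge, downwards
            inward.append(r * n + right)
        if top < bottom:
            for c in range(right - 1, left - 1, -1):  # bottom edge, right to left
                inward.append(bottom * n + c)
        if left < right:
            for r in range(bottom - 1, top, -1):  # left edge, upwards
                inward.append(r * n + left)
        top, bottom, left, right = top + 1, bottom - 1, left + 1, right - 1
    return inward[::-1]

print(spiral_order(3))  # [4, 3, 6, 7, 8, 5, 2, 1, 0]
```

Indexing a NumPy array of patch tokens with this list (e.g. `tokens[spiral_order(4)]` for a 4 x 4 grid) would put the central patches first.
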
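
One common thing to do with `zs` is similarity search: find the galaxies whose embeddings lie closest to a query galaxy. The sketch below ranks rows by cosine similarity with NumPy, assuming `zs` is a 2-D array with one embedding per row. The helper name `cosine_neighbours` and the choice of cosine similarity are ours for illustration, and the toy 2-D vectors stand in for real embeddings, which have many more dimensions.

```python
import numpy as np

def cosine_neighbours(zs: np.ndarray, query: int, k: int = 5) -> np.ndarray:
    """Indices of the k rows of zs most cosine-similar to zs[query], excluding the query."""
    norms = np.linalg.norm(zs, axis=1, keepdims=True)
    unit = zs / np.maximum(norms, 1e-12)  # guard against all-zero rows
    sims = unit @ unit[query]
    sims[query] = -np.inf  # never return the query as its own neighbour
    return np.argsort(-sims)[:k]

# Toy 2-D "embeddings" standing in for the rows of zs.
zs = np.array([
    [1.0, 0.0],
    [0.9, 0.1],
    [0.0, 1.0],
    [-1.0, 0.0],
])
print(cosine_neighbours(zs, query=0, k=2))  # [1 2]
```
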

## Updates and community

AstroPT is an open-to-all UniverseTBD project. Please join the [UniverseTBD](https://universetbd.org) Discord for updates: [https://discord.gg/MNEVegvfJq](https://discord.gg/MNEVegvfJq)