"""Convert a Flax GPT-2 checkpoint to PyTorch and sanity-check the conversion.

Loads a ``FlaxGPT2LMHeadModel`` checkpoint from ``MODEL_DIR``, builds a
PyTorch ``GPT2LMHeadModel`` from the same weights (``from_flax=True``), runs
both models on an identical dummy batch, prints both logits plus their largest
absolute difference, and finally writes the PyTorch weights into ``MODEL_DIR``.
"""
import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel

# Directory holding the Flax checkpoint; the converted PyTorch weights are
# saved back into the same directory.
MODEL_DIR = "./"
# Shape of the dummy batch fed to both models (sequences x tokens).
BATCH_SIZE = 2
SEQ_LEN = 128


def main(model_dir=MODEL_DIR, batch_size=BATCH_SIZE, seq_len=SEQ_LEN):
    """Convert the Flax checkpoint in *model_dir* to PyTorch and compare logits.

    Args:
        model_dir: Directory containing the Flax checkpoint. The PyTorch
            weights are written here after the comparison.
        batch_size: Number of dummy sequences in the comparison batch.
        seq_len: Number of token ids per dummy sequence.
    """
    model_fx = FlaxGPT2LMHeadModel.from_pretrained(model_dir)
    model_pt = GPT2LMHeadModel.from_pretrained(model_dir, from_flax=True)
    # Make sure dropout is disabled so the comparison is deterministic.
    model_pt.eval()

    # All-zero token ids: the content is irrelevant, only the agreement
    # between the two frameworks matters.
    input_ids = np.zeros((batch_size, seq_len), dtype=np.int32)

    # PyTorch embedding lookups are specified for int64 (Long) indices, so
    # widen the int32 array before handing it to torch.
    input_ids_pt = torch.from_numpy(input_ids).long()
    with torch.no_grad():  # inference only; don't build an autograd graph
        logits_pt = model_pt(input_ids_pt).logits
    print(logits_pt)

    logits_fx = model_fx(input_ids).logits
    print(logits_fx)

    # A faithful conversion should give (near-)identical logits.
    max_diff = float(np.max(np.abs(logits_pt.numpy() - np.asarray(logits_fx))))
    print(f"max |logits_pt - logits_fx| = {max_diff:.6g}")

    # Persist the PyTorch weights only after the comparison has been reported.
    model_pt.save_pretrained(model_dir)


if __name__ == "__main__":
    main()