# Doven
# update code.
# f7009b3
import sys, os
# Resolve the repository root: the prefix of this file's path that ends at the
# "Recurrent-Parameter-Generation" directory. abspath() is applied first because
# __file__ may be a relative path (e.g. for the main script before Python 3.9);
# in that case dropping the first path component would discard a real directory
# rather than the empty string that precedes the leading separator.
# Still raises ValueError if the directory name is not part of the path.
_path_parts = os.path.abspath(__file__).split(os.sep)
root = os.sep + os.sep.join(_path_parts[1:_path_parts.index("Recurrent-Parameter-Generation") + 1])
# Make modules under the repo root importable and run from the repo root so the
# relative paths used later in this script resolve against it.
sys.path.append(root)
os.chdir(root)
# torch
import torch
from torch import nn
# parent module (its import name is passed as the first command-line argument)
import importlib
# Import the module named by the first command-line argument and pull out the
# objects this script works with.
item = importlib.import_module(sys.argv[1])
Dataset, train_loader, optimizer = item.Dataset, item.train_loader, item.optimizer
train_set, config, model = item.train_set, item.config, item.model

# The tag is used to build the checkpoint path below, so it has to be set.
assert config.get("tag") is not None, "Remember to set a tag."

# Settings specific to this script, merged over the imported config.
test_config = {"device": "cuda"}
test_config["checkpoint"] = f"./checkpoint/{config['tag']}.pth"
config.update(test_config)
# Model
print('==> Building model..')
# Load the checkpoint onto the CPU. Without map_location, tensors that were
# saved from a GPU are restored onto that GPU and stay alive through the
# module-level `diction`, which would add to the GPU memory reported by
# nvidia-smi later in this script (and fails if that device is not available).
# load_state_dict copies the values into the model's own parameters either way.
diction = torch.load(config["checkpoint"], map_location="cpu")
# Replace the embedding with one whose shape is taken from the checkpoint
# (presumably so load_state_dict does not hit a size mismatch on this entry).
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])
# test
print('==> Defining training..')
def memory_test():
    """Run a few training steps, then show GPU usage and wait for the user.

    Uses the module-level ``model``, ``train_loader`` and ``optimizer``. The
    loop stops after the batch with index 10 (or earlier if the loader is
    exhausted). Afterwards ``nvidia-smi`` is run and the process blocks on
    ``input`` (presumably so the GPU memory stays allocated while it is
    being inspected).
    """
    print("==> start training..")
    model.train()
    for batch_idx, (param, permutation_state) in enumerate(train_loader):
        optimizer.zero_grad()
        # Forward pass under bfloat16 autocast; backward and the optimizer
        # step run outside the autocast context.
        # noinspection PyArgumentList
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            loss = model(output_shape=param.shape,
                         x_0=param.to(model.device),
                         permutation_state=permutation_state.to(model.device))
        loss.backward()
        optimizer.step()
        if batch_idx >= 10:
            break
    os.system("nvidia-smi")
    # CUDA_VISIBLE_DEVICES may be unset; use a placeholder instead of raising
    # KeyError, which would end the process before the pause below.
    visible_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "<CUDA_VISIBLE_DEVICES not set>")
    input(f"This program running on GPU:{visible_gpus}")


if __name__ == "__main__":
    memory_test()