import sys, os
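# Resolve the repository root from this file's path so the script can be run
# from anywhere, then make the repo importable and the working directory.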
root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1])
sys.path.append(root)
os.chdir(root)

import torch
from torch import nn

import importlib
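# Dynamically load the experiment module named on the command line; it must
# expose the Dataset, train_loader, optimizer, train_set, config, and model
# used below.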
item = importlib.import_module(sys.argv[1])
Dataset = item.Dataset
train_loader = item.train_loader
optimizer = item.optimizer
train_set = item.train_set
config = item.config
model = item.model
assert config.get("tag") is not None, "Remember to set a tag."
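
# Test-time overrides applied on top of the experiment config.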
test_config = {
    "device": "cuda",
    "checkpoint": f"./checkpoint/{config['tag']}.pth",
}
config.update(test_config)

print('==> Building model..')
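# The checkpoint's permutation-state embedding size depends on the training
# run, so recreate the layer with the saved shape before load_state_dict.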
diction = torch.load(config["checkpoint"], map_location="cpu")  # load on CPU; moved to device below
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])

print('==> Defining training..')
def memory_test():
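    """Run a few optimizer steps, then report GPU memory usage via nvidia-smi."""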
    print("==> start training..")
    model.train()
    for batch_idx, (param, permutation_state) in enumerate(train_loader):
        optimizer.zero_grad()
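        # Forward pass under bfloat16 autocast; the model returns its loss.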
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            loss = model(output_shape=param.shape,
                         x_0=param.to(model.device),
                         permutation_state=permutation_state.to(model.device))
        loss.backward()
        optimizer.step()
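        # A few batches are enough to reach steady-state memory usage.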
        if batch_idx >= 10:
            break
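    # Show memory usage, then hold the process so the allocation stays visible.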
    os.system("nvidia-smi")
    input(f"This program is running on GPU: {os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}")


if __name__ == "__main__":
    memory_test()