import os
import sys

# Locate the repository root by name inside this file's *absolute* path, then
# make it both importable and the working directory. abspath() is needed
# because __file__ may be relative, in which case the repository directory
# name would not appear in it at all.
_path_parts = os.path.abspath(__file__).split(os.sep)
root = os.sep.join(_path_parts[:_path_parts.index("Recurrent-Parameter-Generation") + 1])
sys.path.append(root)
os.chdir(root)

# torch
import time
import types
import torch
import copy
from torch import nn
from model.diffusion import DDIMSampler, DDPMSampler

# father
# The first command-line argument names the module to import; that module must
# expose `Dataset`, `train_set`, `config` and `model`.
import importlib
if len(sys.argv) < 2:
    raise SystemExit(f"usage: python {os.path.basename(__file__)} <module.to.import>")
item = importlib.import_module(sys.argv[1])
Dataset = item.Dataset
train_set = item.train_set
config = item.config
model = item.model
# Raise instead of assert so the check still runs under `python -O`.
if config.get("tag") is None:
    raise ValueError("Remember to set a tag.")


generate_config = {
    "device": "cuda",
    "num_generated": 1,
    "checkpoint": f"./checkpoint/{config['tag']}.pth",
    # The two "{}" placeholders are filled with (tag, index) by the caller.
    "generated_path": os.path.join(Dataset.generated_path.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    "test_command": os.path.join(Dataset.test_command.rsplit("/", 1)[0], "generated_{}_{}.pth"),
    "need_test": True,
    # inference setting
    "sampler": DDIMSampler,
    "steps": 60,  # only valid when using DDIMSampler
}
config.update(generate_config)


# Model
print('==> Building model..')
# Load onto CPU first: this does not depend on the device the checkpoint was
# saved from, and the model is moved to config["device"] further down.
diction = torch.load(config["checkpoint"], map_location="cpu")
# Re-create the permutation embedding with the shape stored in the checkpoint
# so that load_state_dict() sees matching parameter shapes.
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
# Swap in the sampler selected for inference, reusing the existing network.
model.criteria.diffusion_sampler = config["sampler"](
    model=model.criteria.diffusion_sampler.model,
    beta=config["model_config"]["beta"],
    T=config["model_config"]["T"],
)  # sampler will be covered below
# Keep a copy of the sampler network's condition embedder on the model itself;
# new_sample() applies it once, and the sampler-side one is replaced with
# Identity below (presumably so it is not re-applied inside the sampler).
model.condi_embedder = copy.deepcopy(model.criteria.diffusion_sampler.model.condi_embedder)


@torch.no_grad()
def new_sample(self, x=None, condition=None):
    """Replacement for ``model.sample`` used at inference time.

    Args:
        x: Optional starting tensor passed to ``self.criteria.sample``. When
            ``None``, standard-normal noise of shape
            ``(1, self.sequence_length, self.config["model_dim"])`` is drawn
            on the same device as the embedded condition.
        condition: Forwarded unchanged to ``self.model``.

    Returns:
        Whatever ``self.criteria.sample`` returns for ``(x, z)``.
    """
    z = self.model([1, self.sequence_length, self.config["d_model"]], condition)
    z = self.condi_embedder(z)
    if x is None:
        x = torch.randn((1, self.sequence_length, self.config["model_dim"]), device=z.device)
    # NOTE: the step count comes from the module-level `config`, not self.config.
    x = self.criteria.sample(x, z, steps=config["steps"])
    return x


model.sample = types.MethodType(new_sample, model)
model.criteria.diffusion_sampler.model.condi_embedder = nn.Identity()
model = model.to(config["device"])
# generate
print('==> Defining generate..')


def generate(save_path=config["generated_path"], test_command=config["test_command"], need_test=True):
    """Sample one set of parameters, save it, and optionally run the test command.

    Args:
        save_path: Destination passed to ``train_set.save_params``. The
            default is the raw template from ``config`` (it still contains
            ``{}`` placeholders), so callers are expected to pass a formatted
            path.
        test_command: Shell command executed with ``os.system`` after saving.
        need_test: When true, run ``test_command``.
    """
    print("\n==> Generating..")
    model.eval()
    with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        # CUDA kernels are launched asynchronously; synchronize on both sides
        # so the measured interval covers the actual GPU work.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        prediction = model(sample=True)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
    print("used time (seconds):", end_time - start_time)
    # Peak allocation since program start (never reset here), not per call.
    print("memory usage (GB):", torch.cuda.max_memory_allocated() / (1024 ** 3))
    train_set.save_params(prediction, save_path=save_path)
    if need_test:
        status = os.system(test_command)
        if status != 0:
            # Best effort: report the failure but keep generating.
            print(f"Warning: test command exited with status {status}: {test_command}")
        print("\n")


if __name__ == "__main__":
    # generate and print info
    for i in range(config["num_generated"]):
        index = str(i + 1).zfill(3)
        save_path = config["generated_path"].format(config["tag"], index)
        print("Save to", save_path)
        generate(
            save_path=save_path,
            test_command=config["test_command"].format(config["tag"], index),
            need_test=config["need_test"],
        )