| from __gin__ import dynamic_registration |
| import __main__ as train_script |
| import seqio |
| from t5.data import mixtures |
| from t5x import adafactor |
| from t5x.examples.t5 import network |
| from t5x import gin_utils |
| from t5x import models |
| from t5x import partitioning |
| from t5x import trainer |
| from t5x import utils |
| import tasks |
|
|
| |
| |
# Macros: named values that the bindings further down reference as %NAME.
| BATCH_SIZE = 32 |
| DROPOUT_RATE = 0.1 |
| EVAL_PERIOD = 1000 |
| EVAL_STEPS = 20 |
| EVALUATOR_NUM_EXAMPLES = None |
| EVALUATOR_USE_MEMORY_CACHE = True |
# Checkpoint that utils.RestoreCheckpointConfig.path points at.
| INITIAL_CHECKPOINT_PATH = \ |
| 'gs://nb-t5x-us-central2/norwegian_NCC_plus_English_pluss200k_balanced_bokmaal_nynorsk_t5x_large/checkpoint_1700000' |
| JSON_WRITE_N_RESULTS = None |
| LABEL_SMOOTHING = 0.0 |
| LOSS_NORMALIZING_FACTOR = None |
| MIXTURE_OR_TASK_MODULE = None |
# NOTE(review): presumably registered by the `tasks` module imported at the top
# of this file, since MIXTURE_OR_TASK_MODULE is None - confirm.
| MIXTURE_OR_TASK_NAME = 'translate_long' |
| MODEL = @models.EncoderDecoderModel() |
| MODEL_DIR = 'gs://nb-t5x-us-central2/finetuned/nynorsk_balanced_large_long_v1' |
| OPTIMIZER = @adafactor.Adafactor() |
| RANDOM_SEED = 0 |
| TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512} |
# NOTE(review): 5000 more than the step in the INITIAL_CHECKPOINT_PATH name
# (checkpoint_1700000). Presumably total_steps is an absolute step count, i.e.
# 5000 fine-tuning steps - confirm against the train script.
| TRAIN_STEPS = 1705000 |
| USE_CACHED_TASKS = False |
| USE_HARDWARE_RNG = False |
| VOCABULARY = @seqio.SentencePieceVocabulary() |
| Z_LOSS = 0.0001 |
|
|
| |
| |
# Parameters for adafactor.Adafactor (the object the OPTIMIZER macro evaluates to).
| adafactor.Adafactor.decay_rate = 0.8 |
| adafactor.Adafactor.logical_factor_rules = \ |
| @adafactor.standard_logical_factor_rules() |
| adafactor.Adafactor.step_offset = 0 |
|
|
| |
| |
# Parameters for utils.CheckpointConfig (passed to train as checkpoint_cfg).
# Both fields are built from their own configurables, which are parameterized
# separately in this file.
| utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig() |
| utils.CheckpointConfig.save = @utils.SaveCheckpointConfig() |
|
|
| |
| |
# Parameters for utils.create_learning_rate_scheduler (used as
# trainer.Trainer.learning_rate_fn).
# NOTE(review): factors is 'constant' only; presumably warmup_steps has no
# effect unless a warmup factor is included in `factors` - confirm.
| utils.create_learning_rate_scheduler.base_learning_rate = 0.001 |
| utils.create_learning_rate_scheduler.factors = 'constant' |
| utils.create_learning_rate_scheduler.warmup_steps = 1000 |
|
|
| |
| |
# Parameters for utils.DatasetConfig in the `infer_eval` scope (passed to train
# as infer_eval_dataset_cfg): 'validation' split, no packing, no shuffling,
# fixed seed 42.
| infer_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE |
| infer_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME |
| infer_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE |
| infer_eval/utils.DatasetConfig.pack = False |
| infer_eval/utils.DatasetConfig.seed = 42 |
| infer_eval/utils.DatasetConfig.shuffle = False |
| infer_eval/utils.DatasetConfig.split = 'validation' |
| infer_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS |
| infer_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS |
|
|
| |
| |
# Parameters for utils.DatasetConfig in the `train` scope (passed to train as
# train_dataset_cfg): 'train' split, packed, shuffled, seed left as None.
| train/utils.DatasetConfig.batch_size = %BATCH_SIZE |
| train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME |
| train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE |
| train/utils.DatasetConfig.pack = True |
| train/utils.DatasetConfig.seed = None |
| train/utils.DatasetConfig.shuffle = True |
| train/utils.DatasetConfig.split = 'train' |
| train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS |
| train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS |
|
|
| |
| |
# Parameters for utils.DatasetConfig in the `train_eval` scope (passed to train
# as train_eval_dataset_cfg): 'validation' split, packed, no shuffling,
# fixed seed 42.
| train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE |
| train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME |
| train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE |
| train_eval/utils.DatasetConfig.pack = True |
| train_eval/utils.DatasetConfig.seed = 42 |
| train_eval/utils.DatasetConfig.shuffle = False |
| train_eval/utils.DatasetConfig.split = 'validation' |
| train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS |
| train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS |
|
|
| |
| |
# Parameters for models.EncoderDecoderModel (the object the MODEL macro
# evaluates to). Input and output vocabularies both come from %VOCABULARY.
| models.EncoderDecoderModel.input_vocabulary = %VOCABULARY |
| models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING |
| models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR |
| models.EncoderDecoderModel.module = @network.Transformer() |
| models.EncoderDecoderModel.optimizer_def = %OPTIMIZER |
| models.EncoderDecoderModel.output_vocabulary = %VOCABULARY |
| models.EncoderDecoderModel.z_loss = %Z_LOSS |
|
|
| |
| |
# Parameters for seqio.Evaluator (passed to train as inference_evaluator_cls).
# The logger entries are references without trailing parentheses, so the
# configurables themselves are passed rather than called here.
| seqio.Evaluator.logger_cls = \ |
| [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger] |
| seqio.Evaluator.num_examples = %EVALUATOR_NUM_EXAMPLES |
| seqio.Evaluator.use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE |
|
|
| |
| |
# Parameters for seqio.JSONLogger (one of the entries in
# seqio.Evaluator.logger_cls).
| seqio.JSONLogger.write_n_results = %JSON_WRITE_N_RESULTS |
|
|
| |
| |
# Parameters for partitioning.PjitPartitioner (passed to train as partitioner).
| partitioning.PjitPartitioner.logical_axis_rules = \ |
| @partitioning.standard_logical_axis_rules() |
| partitioning.PjitPartitioner.model_parallel_submesh = None |
| partitioning.PjitPartitioner.num_partitions = 1 |
|
|
| |
| |
# Parameters for utils.RestoreCheckpointConfig (used as
# utils.CheckpointConfig.restore): restores from %INITIAL_CHECKPOINT_PATH.
| utils.RestoreCheckpointConfig.dtype = 'float32' |
| utils.RestoreCheckpointConfig.mode = 'specific' |
| utils.RestoreCheckpointConfig.path = %INITIAL_CHECKPOINT_PATH |
|
|
| |
| |
# Parameters for utils.SaveCheckpointConfig (used as utils.CheckpointConfig.save).
# `period` is a literal, not a macro reference: it currently equals EVAL_PERIOD
# (1000) but will not follow it if that macro is changed.
| utils.SaveCheckpointConfig.dtype = 'float32' |
| utils.SaveCheckpointConfig.keep = None |
| utils.SaveCheckpointConfig.period = 1000 |
| utils.SaveCheckpointConfig.save_dataset = False |
|
|
| |
| |
# Parameters for seqio.SentencePieceVocabulary (the object the VOCABULARY macro
# evaluates to).
| seqio.SentencePieceVocabulary.sentencepiece_model_file = \ |
| 'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model' |
|
|
| |
| |
# Parameters for network.T5Config (used as network.Transformer.config).
# Only dropout_rate comes from a macro; the remaining values are literals.
| network.T5Config.dropout_rate = %DROPOUT_RATE |
| network.T5Config.dtype = 'bfloat16' |
| network.T5Config.emb_dim = 1024 |
| network.T5Config.head_dim = 64 |
| network.T5Config.logits_via_embedding = False |
| network.T5Config.mlp_activations = ('gelu', 'linear') |
| network.T5Config.mlp_dim = 2816 |
| network.T5Config.num_decoder_layers = 24 |
| network.T5Config.num_encoder_layers = 24 |
| network.T5Config.num_heads = 16 |
| network.T5Config.vocab_size = 250112 |
|
|
| |
| |
# Parameters for train_script.train, where train_script is the alias given to
# __main__ in the imports at the top of this file. This binds the macros and
# the scoped dataset configs to the train entry point.
| train_script.train.checkpoint_cfg = @utils.CheckpointConfig() |
| train_script.train.eval_period = %EVAL_PERIOD |
| train_script.train.eval_steps = %EVAL_STEPS |
| train_script.train.infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig() |
| train_script.train.inference_evaluator_cls = @seqio.Evaluator |
| train_script.train.model = %MODEL |
| train_script.train.model_dir = %MODEL_DIR |
| train_script.train.partitioner = @partitioning.PjitPartitioner() |
| train_script.train.random_seed = %RANDOM_SEED |
| train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config |
| train_script.train.total_steps = %TRAIN_STEPS |
| train_script.train.train_dataset_cfg = @train/utils.DatasetConfig() |
| train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() |
| train_script.train.trainer_cls = @trainer.Trainer |
| train_script.train.use_hardware_rng = %USE_HARDWARE_RNG |
|
|
| |
| |
# Parameters for trainer.Trainer (passed to train as trainer_cls).
| trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler() |
| trainer.Trainer.num_microbatches = None |
|
|
| |
| |
# Parameters for network.Transformer (used as models.EncoderDecoderModel.module).
| network.Transformer.config = @network.T5Config() |
|
|