This tutorial fine-tunes a Mask R-CNN model with a MobileNet V2 backbone from the TensorFlow Model Garden package (tensorflow-models).
Model Garden contains a collection of state-of-the-art models implemented with TensorFlow's high-level APIs. The implementations demonstrate best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.
This tutorial demonstrates how to:

1. Use models from the TensorFlow Model Garden package.
2. Fine-tune a pre-built Mask R-CNN with a MobileNet V2 backbone for object detection and instance segmentation on a subset of the LVIS dataset.
3. Prepare the data, configuration, and training task needed to run and export the fine-tuned model.
```
pip install -U -q "tf-models-official"
pip install -U -q remotezip tqdm opencv-python einops
```
```python
import os
import io
import json
import tqdm
import shutil
import pprint
import pathlib
import tempfile
import requests
import collections
import matplotlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from PIL import Image
from six import BytesIO
from etils import epath
from IPython import display
from urllib.request import urlopen
```
```
2023-11-30 12:05:19.630836: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-30 12:05:19.630880: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-30 12:05:19.632442: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
```
```python
import orbit
import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_datasets as tfds

from official.core import exp_factory
from official.core import config_definitions as cfg
from official.vision.data import tfrecord_lib
from official.vision.serving import export_saved_model_lib
from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder
from official.vision.utils.object_detection import visualization_utils
from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image
from official.vision.data.create_coco_tf_record import coco_annotations_to_lists

pp = pprint.PrettyPrinter(indent=4)  # Set Pretty Print Indentation
print(tf.__version__)  # Check the version of tensorflow used

%matplotlib inline
```
```
2.15.0
```
LVIS: A dataset for large vocabulary instance segmentation.
:::tip Note: LVIS uses the COCO 2017 train, validation, and test image sets. If you have already downloaded the COCO images, you only need to download the LVIS annotations. LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
:::
```
# @title Download annotation files
wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
unzip -q lvis_v1_train.json.zip
rm lvis_v1_train.json.zip

wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
unzip -q lvis_v1_val.json.zip
rm lvis_v1_val.json.zip

wget https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip
unzip -q lvis_v1_image_info_test_dev.json.zip
rm lvis_v1_image_info_test_dev.json.zip
```
```
--2023-11-30 12:05:23--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 350264821 (334M) [application/zip]
Saving to: ‘lvis_v1_train.json.zip’

lvis_v1_train.json. 100%[===================>] 334.04M   295MB/s    in 1.1s

2023-11-30 12:05:25 (295 MB/s) - ‘lvis_v1_train.json.zip’ saved [350264821/350264821]

--2023-11-30 12:05:34--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64026968 (61M) [application/zip]
Saving to: ‘lvis_v1_val.json.zip’

lvis_v1_val.json.zi 100%[===================>]  61.06M   184MB/s    in 0.3s

2023-11-30 12:05:34 (184 MB/s) - ‘lvis_v1_val.json.zip’ saved [64026968/64026968]

--2023-11-30 12:05:36--  https://dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.163.189.51, 3.163.189.108, 3.163.189.14, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.163.189.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 384629 (376K) [application/zip]
Saving to: ‘lvis_v1_image_info_test_dev.json.zip’

lvis_v1_image_info_ 100%[===================>] 375.61K  --.-KB/s    in 0.03s

2023-11-30 12:05:37 (12.3 MB/s) - ‘lvis_v1_image_info_test_dev.json.zip’ saved [384629/384629]
```
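Before parsing the annotations, it can help to peek at what is inside one of the downloaded JSON files. The snippet below is a small sketch (not part of the original tutorial flow) that assumes `lvis_v1_val.json` now sits in the working directory; LVIS follows a COCO-style layout with top-level `images`, `annotations`, and `categories` lists.

```python
import json  # already imported above; repeated here so the snippet is self-contained

# Hypothetical sanity check: inspect the structure of the validation annotations.
with open('./lvis_v1_val.json') as f:
  lvis_val = json.load(f)

print(list(lvis_val.keys()))        # e.g. ['images', 'annotations', 'categories', ...]
print(len(lvis_val['images']), 'images')
print(len(lvis_val['categories']), 'categories')
print(lvis_val['categories'][0])    # one category dict with an 'id' and a 'name'
```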
```python
# @title Lvis annotation parsing

# Annotations with invalid bounding boxes. Will not be used.
_INVALID_ANNOTATIONS = [
    # Train split.
    662101, 81217, 462924, 227817, 29381, 601484, 412185, 504667, 572573,
    91937, 239022, 181534, 101685,
    # Validation split.
    36668, 57541, 33126, 10932,
]

def get_category_map(annotation_path, num_classes):
  """Maps contiguous 1-based label ids to the first `num_classes` categories."""
  with epath.Path(annotation_path).open() as f:
    data = json.load(f)
  category_map = {id+1: {'id': cat_dict['id'], 'name': cat_dict['name']}
                  for id, cat_dict in enumerate(data['categories'][:num_classes])}
  return category_map

class LvisAnnotation:
  """LVIS annotation helper class.

  The format of the annotations is explained on
  https://www.lvisdataset.org/dataset.
  """

  def __init__(self, annotation_path):
    with epath.Path(annotation_path).open() as f:
      data = json.load(f)
    self._data = data

    # Keep only annotations whose category is in the global `category_ids`
    # (defined in the next cell, before these helpers are called).
    img_id2annotations = collections.defaultdict(list)
    for a in self._data.get('annotations', []):
      if a['category_id'] in category_ids:
        img_id2annotations[a['image_id']].append(a)
    self._img_id2annotations = {
        k: list(sorted(v, key=lambda a: a['id']))
        for k, v in img_id2annotations.items()
    }

  @property
  def categories(self):
    """Return the category dicts, as sorted in the file."""
    return self._data['categories']

  @property
  def images(self):
    """Return the image dicts, as sorted in the file."""
    sub_images = []
    for image_info in self._data['images']:
      if image_info['id'] in self._img_id2annotations:
        sub_images.append(image_info)
    return sub_images

  def get_annotations(self, img_id):
    """Return all annotations associated with the image id string."""
    # Some images don't have any annotations. Return empty list instead.
    return self._img_id2annotations.get(img_id, [])

def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):
  """Generate sharded TFRecords from an LVIS annotation file."""
  lvis_annotation = LvisAnnotation(annotation_file)

  def _process_example(prefix, image_info, id_to_name_map):
    # Search image dirs.
    filename = pathlib.Path(image_info['coco_url']).name
    image = tf.io.read_file(os.path.join(IMGS_DIR, filename))
    instances = lvis_annotation.get_annotations(img_id=image_info['id'])
    instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]
    is_crowd = {'iscrowd': 0}
    instances = [dict(x, **is_crowd) for x in instances]
    neg_category_ids = image_info.get('neg_category_ids', [])
    not_exhaustive_category_ids = image_info.get(
        'not_exhaustive_category_ids', []
    )
    data, _ = coco_annotations_to_lists(instances,
                                        id_to_name_map,
                                        image_info['height'],
                                        image_info['width'],
                                        include_masks=True)
    keys_to_features = {
        'image/encoded': tfrecord_lib.convert_to_feature(image.numpy()),
        'image/filename': tfrecord_lib.convert_to_feature(filename.encode('utf8')),
        'image/format': tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),
        'image/height': tfrecord_lib.convert_to_feature(image_info['height']),
        'image/width': tfrecord_lib.convert_to_feature(image_info['width']),
        'image/source_id': tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),
        'image/object/bbox/xmin': tfrecord_lib.convert_to_feature(data['xmin']),
        'image/object/bbox/xmax': tfrecord_lib.convert_to_feature(data['xmax']),
        'image/object/bbox/ymin': tfrecord_lib.convert_to_feature(data['ymin']),
        'image/object/bbox/ymax': tfrecord_lib.convert_to_feature(data['ymax']),
        'image/object/class/text': tfrecord_lib.convert_to_feature(data['category_names']),
        'image/object/class/label': tfrecord_lib.convert_to_feature(data['category_id']),
        'image/object/is_crowd': tfrecord_lib.convert_to_feature(data['is_crowd']),
        'image/object/area': tfrecord_lib.convert_to_feature(data['area'], 'float_list'),
        'image/object/mask': tfrecord_lib.convert_to_feature(data['encoded_mask_png']),
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=keys_to_features))
    return example

  writers = [
      tf.io.TFRecordWriter(
          tf_records_dir + prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards))
      for i in range(num_shards)
  ]
  id_to_name_map = {cat_dict['id']: cat_dict['name']
                    for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}

  for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):
    # Download each image from its COCO URL before encoding the example.
    img_data = requests.get(image_info['coco_url'], stream=True).content
    img_name = image_info['coco_url'].split('/')[-1]
    with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:
      handler.write(img_data)
    tf_example = _process_example(prefix, image_info, id_to_name_map)
    writers[idx % num_shards].write(tf_example.SerializeToString())
  del lvis_annotation
```
```python
_URLS = {
    'train_images': 'http://images.cocodataset.org/zips/train2017.zip',
    'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',
    'test_images': 'http://images.cocodataset.org/zips/test2017.zip',
}

train_prefix = 'train'
valid_prefix = 'val'

train_annotation_path = './lvis_v1_train.json'
valid_annotation_path = './lvis_v1_val.json'

IMGS_DIR = './lvis_sub_dataset/'
tf_records_dir = './lvis_tfrecords/'

if not os.path.exists(IMGS_DIR):
  os.mkdir(IMGS_DIR)
if not os.path.exists(tf_records_dir):
  os.mkdir(tf_records_dir)

NUM_CLASSES = 3
category_index = get_category_map(valid_annotation_path, NUM_CLASSES)
category_ids = list(category_index.keys())
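```

At this point `category_index` maps the contiguous label ids 1 to `NUM_CLASSES` onto the first three LVIS categories, and `category_ids` drives the annotation filtering above. Printing them is a quick way to see which classes the fine-tuning will use; this is a small sketch, and the exact names depend on the category ordering in the annotation file.

```python
# Inspect the subset of categories kept for fine-tuning.
pp.pprint(category_index)
print('category ids used for filtering:', category_ids)
```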
```python
# The helper functions above are adapted from the TensorFlow Datasets LVIS builder:
# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py
_generate_tf_records(train_prefix, _URLS['train_images'], train_annotation_path)
```
```
100%|██████████| 2338/2338 [16:14<00:00,  2.40it/s]
```
```python
_generate_tf_records(valid_prefix, _URLS['validation_images'], valid_annotation_path)
```
```
100%|██████████| 422/422 [02:56<00:00,  2.40it/s]
```
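As a quick check that the TFRecords were written correctly, you can decode one serialized example with the `TfExampleDecoder` imported earlier. This is a sketch rather than part of the original flow; it assumes the decoder accepts a serialized `tf.train.Example` string tensor and returns a dictionary with keys such as `image` and `groundtruth_boxes`.

```python
# Hypothetical sanity check on the generated TFRecords.
raw_records = tf.data.TFRecordDataset(
    tf.io.gfile.glob('./lvis_tfrecords/train*')).take(1)
decoder = TfExampleDecoder(include_mask=True)

for serialized_example in raw_records:
  decoded = decoder.decode(serialized_example)
  print('image shape:', decoded['image'].shape)
  print('num boxes:  ', decoded['groundtruth_boxes'].shape[0])
```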
```python
train_data_input_path = './lvis_tfrecords/train*'
valid_data_input_path = './lvis_tfrecords/val*'
test_data_input_path = './lvis_tfrecords/test*'
model_dir = './trained_model/'
export_dir = './exported_model/'
```
```python
if not os.path.exists(model_dir):
  os.mkdir(model_dir)
```
In Model Garden, the collections of parameters that define a model are called configs. Model Garden can create a config based on a known set of parameters via a factory.
Use the `maskrcnn_mobilenet_coco` experiment configuration, as defined by `tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco`. You can find all the registered experiments here.

The configuration defines an experiment to train a Mask R-CNN model with MobileNet as the backbone and FPN as the decoder. The default configuration is trained on COCO train2017 and evaluated on COCO val2017.

There are also other alternative experiments available, such as `maskrcnn_resnetfpn_coco`, `maskrcnn_spinenet_coco`, and more. You can switch to them by changing the experiment name argument passed to the `get_exp_config` function.
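For example, swapping in a ResNet-FPN backbone is just a matter of requesting a different registered experiment name. The snippet below is only an illustration; this tutorial continues with the MobileNet configuration.

```python
# Hypothetical: the same factory call with a different registered experiment.
alt_exp_config = exp_factory.get_exp_config('maskrcnn_resnetfpn_coco')
print(alt_exp_config.task.model.backbone.type)  # expected to report 'resnet' rather than 'mobilenet'
```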
```python
exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')
```
```
model_ckpt_path = './model_ckpt/'
if not os.path.exists(model_ckpt_path):
  os.mkdir(model_ckpt_path)

!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'
!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'
```
```
Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001...
Operation completed over 1 objects/26.9 MiB.
Copying gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index...
Operation completed over 1 objects/7.5 KiB.
```
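If you want to confirm the checkpoint downloaded cleanly before wiring it into the config, an optional check is to list a few of its variables. `tf.train.list_variables` works on the checkpoint prefix, so this sketch assumes the two files above landed in `./model_ckpt/`.

```python
# Optional sanity check on the downloaded backbone checkpoint.
ckpt_vars = tf.train.list_variables('./model_ckpt/ckpt-180648')
print(len(ckpt_vars), 'variables in checkpoint')
print(ckpt_vars[:3])  # a few (name, shape) pairs
```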
```python
BATCH_SIZE = 8
HEIGHT, WIDTH = 256, 256
IMG_SHAPE = [HEIGHT, WIDTH, 3]

# Backbone Config
exp_config.task.annotation_file = None
exp_config.task.freeze_backbone = True
exp_config.task.init_checkpoint = "./model_ckpt/ckpt-180648"
exp_config.task.init_checkpoint_modules = "backbone"

# Model Config
exp_config.task.model.num_classes = NUM_CLASSES + 1
exp_config.task.model.input_size = IMG_SHAPE

# Training Data Config
exp_config.task.train_data.input_path = train_data_input_path
exp_config.task.train_data.dtype = 'float32'
exp_config.task.train_data.global_batch_size = BATCH_SIZE
exp_config.task.train_data.shuffle_buffer_size = 64
exp_config.task.train_data.parser.aug_scale_max = 1.0
exp_config.task.train_data.parser.aug_scale_min = 1.0

# Validation Data Config
exp_config.task.validation_data.input_path = valid_data_input_path
exp_config.task.validation_data.dtype = 'float32'
exp_config.task.validation_data.global_batch_size = BATCH_SIZE
```
```python
logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]

if 'GPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
  print('This may be broken in Colab.')
  device = 'TPU'
else:
  print('Running on CPU is slow, so only train for a few steps.')
  device = 'CPU'

train_steps = 2000
exp_config.trainer.steps_per_loop = 200  # steps_per_loop = num_of_training_examples // train_batch_size

exp_config.trainer.summary_interval = 200
exp_config.trainer.checkpoint_interval = 200
exp_config.trainer.validation_interval = 200
exp_config.trainer.validation_steps = 200  # validation_steps = num_of_validation_examples // eval_batch_size
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05
```
```
This may be broken in Colab.
```
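The comments above suggest deriving the loop sizes from the dataset. As an illustration only (the tutorial keeps the round value of 200), with the 2,338 training and 422 validation images written earlier and a global batch size of 8 you would get roughly:

```python
# Illustrative arithmetic for the trainer loop sizes.
num_train_examples = 2338
num_val_examples = 422

steps_per_loop = num_train_examples // BATCH_SIZE    # 292
validation_steps = num_val_examples // BATCH_SIZE    # 52
print(steps_per_loop, validation_steps)
```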
```python
pp.pprint(exp_config.as_dict())
display.Javascript("google.colab.output.setIframeHeight('500px');")
```
```
{ 'runtime': { 'all_reduce_alg': None, 'batchnorm_spatial_persistent': False, 'dataset_num_private_threads': None, 'default_shard_dim': -1, 'distribution_strategy': 'mirrored', 'enable_xla': False, 'gpu_thread_mode': None, 'loss_scale': None, 'mixed_precision_dtype': 'bfloat16', 'num_cores_per_replica': 1, 'num_gpus': 0, 'num_packs': 1, 'per_gpu_thread_count': 0, 'run_eagerly': False, 'task_index': -1, 'tpu': None, 'tpu_enable_xla_dynamic_padder': None, 'use_tpu_mp_strategy': False, 'worker_hosts': None}, 'task': { 'allow_image_summary': False, 'allowed_mask_class_ids': None, 'annotation_file': None, 'differential_privacy_config': None, 'freeze_backbone': True, 'init_checkpoint': './model_ckpt/ckpt-180648', 'init_checkpoint_modules': 'backbone', 'losses': { 'class_weights': None, 'frcnn_box_weight': 1.0, 'frcnn_class_loss_top_k_percent': 1.0, 'frcnn_class_use_binary_cross_entropy': False, 'frcnn_class_weight': 1.0, 'frcnn_huber_loss_delta': 1.0, 'l2_weight_decay': 4e-05, 'loss_weight': 1.0, 'mask_weight': 1.0, 'rpn_box_weight': 1.0, 'rpn_huber_loss_delta': 0.1111111111111111, 'rpn_score_weight': 1.0}, 'model': { 'anchor': { 'anchor_size': 3, 'aspect_ratios': [0.5, 1.0, 2.0], 'num_scales': 1}, 'backbone': { 'mobilenet': { 'filter_size_scale': 1.0, 'model_id': 'MobileNetV2', 'output_intermediate_endpoints': False, 'output_stride': None, 'stochastic_depth_drop_rate': 0.0}, 'type': 'mobilenet'}, 'decoder': { 'fpn': { 'fusion_type': 'sum', 'num_filters': 128, 'use_keras_layer': False, 'use_separable_conv': True}, 'type': 'fpn'}, 'detection_generator': { 'apply_nms': True, 'max_num_detections': 100, 'nms_iou_threshold': 0.5, 'nms_version': 'v2', 'pre_nms_score_threshold': 0.05, 'pre_nms_top_k': 5000, 'soft_nms_sigma': None, 'use_cpu_nms': False, 'use_sigmoid_probability': False}, 'detection_head': { 'cascade_class_ensemble': False, 'class_agnostic_bbox_pred': False, 'fc_dims': 512, 'num_convs': 4, 'num_fcs': 1, 'num_filters': 128, 'use_separable_conv': True}, 'include_mask': True, 'input_size': [256, 256, 3], 'mask_head': { 'class_agnostic': False, 'num_convs': 4, 'num_filters': 128, 'upsample_factor': 2, 'use_separable_conv': True}, 'mask_roi_aligner': { 'crop_size': 14, 'sample_offset': 0.5}, 'mask_sampler': {'num_sampled_masks': 128}, 'max_level': 6, 'min_level': 3, 'norm_activation': { 'activation': 'relu6', 'norm_epsilon': 0.001, 'norm_momentum': 0.99, 'use_sync_bn': True}, 'num_classes': 4, 'outer_boxes_scale': 1.0, 'roi_aligner': { 'crop_size': 7, 'sample_offset': 0.5}, 'roi_generator': { 'nms_iou_threshold': 0.7, 'num_proposals': 1000, 'pre_nms_min_size_threshold': 0.0, 'pre_nms_score_threshold': 0.0, 'pre_nms_top_k': 2000, 'test_nms_iou_threshold': 0.7, 'test_num_proposals': 1000, 'test_pre_nms_min_size_threshold': 0.0, 'test_pre_nms_score_threshold': 0.0, 'test_pre_nms_top_k': 1000, 'use_batched_nms': False}, 'roi_sampler': { 'background_iou_high_threshold': 0.5, 'background_iou_low_threshold': 0.0, 'cascade_iou_thresholds': None, 'foreground_fraction': 0.25, 'foreground_iou_threshold': 0.5, 'mix_gt_boxes': True, 'num_sampled_rois': 512}, 'rpn_head': { 'num_convs': 1, 'num_filters': 128, 'use_separable_conv': True} }, 'name': None, 'per_category_metrics': False, 'train_data': { 'apply_tf_data_service_before_batching': False, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'attribute_names': [ ], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 
'drop_remainder': True, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': './lvis_tfrecords/train*', 'is_training': True, 'num_examples': -1, 'parser': { 'aug_rand_hflip': True, 'aug_rand_vflip': False, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'mask_crop_size': 112, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'pad': True, 'rpn_batch_size_per_im': 256, 'rpn_fg_fraction': 0.5, 'rpn_match_threshold': 0.7, 'rpn_unmatched_threshold': 0.3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 64, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None, 'weights': None}, 'use_approx_instance_metrics': False, 'use_coco_metrics': True, 'use_wod_metrics': False, 'validation_data': { 'apply_tf_data_service_before_batching': False, 'autotune_algorithm': None, 'block_length': 1, 'cache': False, 'cycle_length': None, 'decoder': { 'simple_decoder': { 'attribute_names': [ ], 'mask_binarize_threshold': None, 'regenerate_source_id': False}, 'type': 'simple_decoder'}, 'deterministic': None, 'drop_remainder': False, 'dtype': 'float32', 'enable_shared_tf_data_service_between_parallel_trainers': False, 'enable_tf_data_service': False, 'file_type': 'tfrecord', 'global_batch_size': 8, 'input_path': './lvis_tfrecords/val*', 'is_training': False, 'num_examples': -1, 'parser': { 'aug_rand_hflip': False, 'aug_rand_vflip': False, 'aug_scale_max': 1.0, 'aug_scale_min': 1.0, 'aug_type': None, 'mask_crop_size': 112, 'match_threshold': 0.5, 'max_num_instances': 100, 'num_channels': 3, 'pad': True, 'rpn_batch_size_per_im': 256, 'rpn_fg_fraction': 0.5, 'rpn_match_threshold': 0.7, 'rpn_unmatched_threshold': 0.3, 'skip_crowd_during_training': True, 'unmatched_threshold': 0.5}, 'prefetch_buffer_size': None, 'seed': None, 'sharding': True, 'shuffle_buffer_size': 10000, 'tf_data_service_address': None, 'tf_data_service_job_name': None, 'tfds_as_supervised': False, 'tfds_data_dir': '', 'tfds_name': '', 'tfds_skip_decoding_feature': '', 'tfds_split': '', 'trainer_id': None, 'weights': None} }, 'trainer': { 'allow_tpu_summary': False, 'best_checkpoint_eval_metric': '', 'best_checkpoint_export_subdir': '', 'best_checkpoint_metric_comp': 'higher', 'checkpoint_interval': 200, 'continuous_eval_timeout': 3600, 'eval_tf_function': True, 'eval_tf_while_loop': False, 'loss_upper_bound': 1000000.0, 'max_to_keep': 5, 'optimizer_config': { 'ema': None, 'learning_rate': { 'cosine': { 'alpha': 0.0, 'decay_steps': 2000, 'initial_learning_rate': 0.07, 'name': 'CosineDecay', 'offset': 0}, 'type': 'cosine'}, 'optimizer': { 'sgd': { 'clipnorm': None, 'clipvalue': None, 'decay': 0.0, 'global_clipnorm': None, 'momentum': 0.9, 'name': 'SGD', 'nesterov': False}, 'type': 'sgd'}, 'warmup': { 'linear': { 'name': 'linear', 'warmup_learning_rate': 0.05, 'warmup_steps': 200}, 'type': 'linear'} }, 'preemption_on_demand_checkpoint': True, 'recovery_begin_steps': 0, 'recovery_max_trials': 0, 'steps_per_loop': 200, 'summary_interval': 200, 'train_steps': 2000, 'train_tf_function': True, 'train_tf_while_loop': True, 'validation_interval': 200, 'validation_steps': 200, 'validation_summary_subdir': 'validation'} } <IPython.core.display.Javascript 
object>
```
```python
# Setting up the Strategy
if exp_config.runtime.mixed_precision_dtype == tf.float16:
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

if 'GPU' in ''.join(logical_device_names):
  distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
  tf.tpu.experimental.initialize_tpu_system()
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')
  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
  print('Warning: this will be really slow.')
  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])

print("Done")
```
```
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Done
```
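One detail worth keeping in mind is that the global batch size of 8 configured above is split across replicas by the strategy. A quick way to see the split (a sketch, not part of the original tutorial) is:

```python
# With 4 mirrored GPUs, a global batch of 8 means 2 examples per replica.
print('replicas in sync:', distribution_strategy.num_replicas_in_sync)
print('per-replica batch:', BATCH_SIZE // distribution_strategy.num_replicas_in_sync)
```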
Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`. The `Task` object has all the methods necessary for building the dataset, building the model, and running training and evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`.
```python
with distribution_strategy.scope():
  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)
```
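With the task and strategy in place, training would normally be kicked off through the Model Garden training driver mentioned above. The call below is a sketch of that step using `tfm.core.train_lib.run_experiment` with the objects defined in this tutorial; check the Model Garden documentation for the exact arguments in your installed version.

```python
# Sketch: drive training and evaluation with the Model Garden trainer.
model, eval_logs = tfm.core.train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=exp_config,
    model_dir=model_dir,
    run_post_eval=True)
```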
:::info Originally published on the TensorFlow website, this article appears here under a new headline and is licensed under CC BY 4.0. Code samples shared under the Apache 2.0 License.
:::