Guide to Google’s STAC: An SSL Framework For Object Detection
The Google Brain team has introduced STAC, a semi-supervised learning (SSL) framework to perform object detection in a simplified way. STAC outperforms most supervised approaches with greater data efficiency and simplicity. Further, STAC opens a new gateway to SSL-based visual object detection.
The first challenge in an object detection task is preparing annotated image data. Large public benchmark datasets have eased this problem considerably. The next challenge is the compute and memory needed to handle such large volumes of data. Research labs and a handful of companies have the necessary hardware, but what about the many production models around the world that must be prepared for task-specific deployment? One workaround is to deploy a large pre-trained model trained on 1000 object classes, which needs neither data preparation nor compute power for training. But if the task is to detect only 10 classes, there is no need to deploy a billion-parameter pre-trained model trained on data from 1000 classes.
Semi-supervised learning (SSL) addresses both challenges to a large extent. SSL requires only a fraction of the data to be annotated and leaves the rest unannotated, which removes the need for a huge annotated dataset. SSL models also enable task-specific training with freshly prepared data. Moreover, semi-supervised approaches can yield better performance than fully supervised approaches trained on the same amount of annotated data. In recent years, image classification using SSL has become popular among researchers. Despite some attempts, however, there has been a gap in generalizing the SSL approach to object detection.
To this end, Kihyuk Sohn, Zizhao Zhang, Chun-Liang Li, Han Zhang, Chen-Yu Lee, and Tomas Pfister from Google Cloud AI Research and Google Brain have developed STAC, an SSL framework intended exclusively for object detection, with two newly introduced tunable hyperparameters. It follows a consistency-based self-training strategy to efficiently train the model with minimal annotated data and strongly augmented unannotated data.
How does STAC really work?
STAC is provided with a small amount of annotated data and a large amount of unannotated data. Semi-supervised learning combines supervised learning on annotated data with unsupervised learning on unannotated data. STAC employs a supervised detection model, such as Faster R-CNN, as its teacher model. The annotated images are fed into the teacher model to train it. With this, the supervised learning part is finished.
The trained teacher model is used to infer bounding boxes for the objects in the unannotated images. While inferring the bounding boxes, the model generates a confidence value for each one. Non-Maximum Suppression (NMS) is applied to the bounding boxes as post-processing. Further, STAC introduces a hyperparameter that acts as a cut-off (threshold) on the confidence value. Bounding boxes with confidence values above this cut-off are retained, whereas the rest are discarded. The retained bounding boxes are called pseudo labels, and the images with pseudo labels are called pseudo-labeled images.
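The post-processing described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the code used in the STAC repository: the function names (`iou`, `nms`, `make_pseudo_labels`), the IoU threshold of 0.5, and the default cut-off `tau=0.9` are assumptions made for the example, and class labels are ignored for brevity (real detectors usually run NMS per class).

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: List[Box], scores: List[float], iou_threshold: float = 0.5) -> List[int]:
    """Greedy NMS: visit boxes by descending score and keep a box only if it
    does not overlap an already kept box by more than iou_threshold."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


def make_pseudo_labels(boxes, scores, tau=0.9, iou_threshold=0.5):
    """Apply NMS, then keep only boxes whose confidence reaches the cut-off tau.
    tau is a hypothetical name for the confidence threshold hyperparameter."""
    kept = nms(boxes, scores, iou_threshold)
    return [(boxes[i], scores[i]) for i in kept if scores[i] >= tau]


if __name__ == "__main__":
    boxes = [(10, 10, 50, 50), (12, 12, 52, 52), (100, 100, 150, 150), (200, 200, 220, 220)]
    scores = [0.95, 0.80, 0.92, 0.40]
    # The second box is suppressed by NMS; the fourth falls below the cut-off.
    print(make_pseudo_labels(boxes, scores))
    # prints [((10, 10, 50, 50), 0.95), ((100, 100, 150, 150), 0.92)]
```

Raising `tau` yields fewer but more reliable pseudo labels, while lowering it keeps more boxes at the cost of more noise.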
Strong augmentation strategies such as colour transformations, box-level geometric transformations, global transformations, and Cutout are applied to the unannotated images. The detector is trained with both the annotated images and the pseudo-labeled unannotated images with strong augmentations. STAC introduces a tunable loss weight to control the contribution of the unsupervised loss.
Apart from the teacher model, STAC is simple: it adds just two tunable hyperparameters (one for bounding box confidence thresholding and another for unsupervised loss control) and needs no further supervision or manual intervention.
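One detail worth illustrating is that geometric augmentations must move the pseudo-label boxes together with the pixels, whereas Cutout only alters pixels. The sketch below shows a horizontal flip and a single Cutout patch on a tiny image stored as nested lists. The helper names (`hflip`, `cutout`), the box convention (exclusive `x2`, `y2`), and the patch size are assumptions chosen for the example; the actual augmentation policy in STAC is richer than this.

```python
import random
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2) in pixels, x2 and y2 exclusive
Image = List[List[int]]          # rows of pixel values


def hflip(image: Image, boxes: List[Box]) -> Tuple[Image, List[Box]]:
    """Flip the image left to right and mirror every box accordingly."""
    width = len(image[0])
    flipped = [row[::-1] for row in image]
    new_boxes = [(width - x2, y1, width - x1, y2) for (x1, y1, x2, y2) in boxes]
    return flipped, new_boxes


def cutout(image: Image, size: int, rng: random.Random, fill: int = 0) -> Image:
    """Overwrite one random size x size square with a constant fill value.
    Boxes are left unchanged, so the detector must cope with the occlusion."""
    height, width = len(image), len(image[0])
    top = rng.randrange(0, height - size + 1)
    left = rng.randrange(0, width - size + 1)
    out = [row[:] for row in image]
    for y in range(top, top + size):
        for x in range(left, left + size):
            out[y][x] = fill
    return out


if __name__ == "__main__":
    img = [[0] * 6 for _ in range(4)]
    img[1][1] = img[1][2] = 7          # an "object" covering box (1, 1, 3, 2)
    flipped, boxes = hflip(img, [(1, 1, 3, 2)])
    print(boxes)                        # the box now covers the mirrored object
    print(cutout(flipped, size=2, rng=random.Random(0)))
```

Passing an explicit `random.Random` instance keeps the augmentation reproducible, which helps when debugging a training pipeline.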
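To make the role of the two hyperparameters concrete, here is a small sketch of how they might enter the training objective: the confidence cut-off decides which pseudo-labeled boxes contribute, and the loss weight scales the unsupervised term before it is added to the supervised one. The function name `stac_total_loss`, the parameter names `tau` and `lambda_u`, and the averaging over retained boxes are assumptions for illustration. The defaults mirror the values passed as `TRAIN.CONFIDENCE` and `TRAIN.WU` in the training command later in this guide, on the assumption that those flags correspond to these two hyperparameters.

```python
from typing import Sequence


def stac_total_loss(
    sup_loss: float,
    unsup_losses: Sequence[float],
    confidences: Sequence[float],
    tau: float = 0.9,
    lambda_u: float = 2.0,
) -> float:
    """Combine supervised and unsupervised losses.

    sup_loss      loss on the annotated batch
    unsup_losses  per-box losses on the pseudo-labeled batch
    confidences   teacher confidence of each pseudo-labeled box
    tau           confidence cut-off (hypothetical name)
    lambda_u      weight of the unsupervised loss (hypothetical name)
    """
    # Only boxes whose teacher confidence reaches tau contribute.
    kept = [loss for loss, conf in zip(unsup_losses, confidences) if conf >= tau]
    # Averaging over retained boxes is an assumption made for this sketch.
    unsup_loss = sum(kept) / len(kept) if kept else 0.0
    return sup_loss + lambda_u * unsup_loss


if __name__ == "__main__":
    total = stac_total_loss(1.0, [0.5, 0.3, 2.0], [0.95, 0.92, 0.40])
    print(round(total, 3))
```

Setting `lambda_u` to 0 reduces the objective to plain supervised training, which is a convenient sanity check.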
Python Implementation of STAC
The reference setup for STAC uses a Python environment with TensorFlow version 1.14 and a CUDA runtime with 8 GPUs. This implementation closely follows the official source code repository of STAC. Clone the source code to the local (or virtual) environment using the following command.
!git clone https://github.com/google-research/ssl_detection.git
Create the environment by installing the dependencies using the following commands.
%%bash
cd /content/ssl_detection/
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages env3
. env3/bin/activate
pip install -r requirements.txt
python -c 'import tensorflow as tf; print(tf.__version__)'
# install coco apis
pip3 install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
Prepare the COCO dataset using the following commands. Unzip the compressed files once the download is finished.
%%bash
mkdir -p /content/coco/
cd /content/coco/
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/unlabeled2017.zip
unzip annotations_trainval2017.zip -d .
unzip -q train2017.zip -d .
unzip -q val2017.zip -d .
unzip -q unlabeled2017.zip -d .
Similarly, download the VOC datasets and extract the archives using the following commands.
%%bash
mkdir -p /content/voc/
cd /content/voc/
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_06-Nov-2007.tar
tar -xf VOCtest_06-Nov-2007.tar
tar -xf VOCtrainval_11-May-2012.tar
Generate the annotated and unannotated data splits from the downloaded annotated dataset.
%%bash
cd /content/ssl_detection/prepare_datasets
for seed in 1 2 3 4 5; do
    for percent in 1 2 5 10 20; do
        # each split is generated as a background job
        python3 prepare_coco_data.py --percent $percent --seed $seed &
    done
done
# wait for all background jobs to finish before the cell returns
wait
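The loop above produces one split for every combination of seed and percentage. Conceptually, such a split is a seeded random partition of the annotated images into a small labeled subset and a remainder whose labels are withheld. The sketch below illustrates that idea only; it is not the logic of `prepare_coco_data.py`, and the function name `split_labeled_unlabeled` and the rounding rule are assumptions.

```python
import random
from typing import Iterable, List, Tuple


def split_labeled_unlabeled(
    image_ids: Iterable[int], percent: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Partition image ids into a labeled subset of roughly `percent` percent
    and an unlabeled remainder. The same seed always gives the same split."""
    ids = sorted(image_ids)           # fixed starting order for reproducibility
    rng = random.Random(seed)         # private generator, global state untouched
    rng.shuffle(ids)
    n_labeled = max(1, round(len(ids) * percent / 100))
    return sorted(ids[:n_labeled]), sorted(ids[n_labeled:])


if __name__ == "__main__":
    labeled, unlabeled = split_labeled_unlabeled(range(1000), percent=10, seed=1)
    print(len(labeled), len(unlabeled))  # prints 100 900
```

Using several seeds per percentage, as the shell loop does, makes it possible to report results averaged over different random subsets.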
Download the JSON files for the VOC and COCO datasets.
%%bash
cd /content/
wget https://storage.cloud.google.com/gresearch/ssl_detection/STAC_JSON.tar
# extract the archive under the same name it was downloaded with
tar -xf STAC_JSON.tar
Download the ImageNet-pretrained ResNet-50 backbone weights used to initialise Faster R-CNN.
%%bash
cd /content/coco/
wget http://models.tensorpack.com/FasterRCNN/ImageNet-R50-AlignPadding.npz
Prepare the COCO dataset for training. Define a path for saving model checkpoints during training, and select the GPUs that CUDA exposes to the training job. The following commands use 8 GPUs. (Users can opt for 4, 16, or 32 GPUs.)
%%bash
cd /content/ssl_detection/detection
# Note: each %%bash cell runs in its own shell, so the cd above and the
# variables below must be repeated in (or merged into) the cells that use them.
# Labeled and Unlabeled datasets
DATASET=coco_train2017.1@10
UNLABELED_DATASET=${DATASET}-unlabeled
# PATH to save trained models
CKPT_PATH=result/${DATASET}
# PATH to save pseudo labels for unlabeled data
PSEUDO_PATH=${CKPT_PATH}/PSEUDO_DATA
# Train with 8 GPUs
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
Train the teacher model, Faster R-CNN, with the above settings on the prepared dataset.
%%bash
python3 train_stg1.py \
    --logdir ${CKPT_PATH} --simple_path --config \
    BACKBONE.WEIGHTS=/content/coco/ImageNet-R50-AlignPadding.npz \
    DATA.BASEDIR=/content/coco/ \
    DATA.TRAIN="('${DATASET}',)" \
    MODE_MASK=False \
    FRCNN.BATCH_PER_IM=64 \
    PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
    TRAIN.EVAL_PERIOD=20 \
    TRAIN.AUGTYPE_LAB='default'
Evaluate the teacher model using the following commands.
%%bash
if [ ! -d ${PSEUDO_PATH} ]; then
    mkdir -p ${PSEUDO_PATH}
fi
# model-180000 is the last checkpoint
# save eval.json at $PSEUDO_PATH
python3 predict.py \
    --evaluate ${PSEUDO_PATH}/eval.json \
    --load "${CKPT_PATH}"/model-180000 \
    --config \
    DATA.BASEDIR=/content/coco/ \
    DATA.TRAIN="('${UNLABELED_DATASET}',)"
Once the teacher model is trained, it can be used to prepare pseudo labels for the unannotated images.
%%bash
python3 predict.py \
    --predict_unlabeled ${PSEUDO_PATH} \
    --load "${CKPT_PATH}"/model-180000 \
    --config \
    DATA.BASEDIR=/content/coco/ \
    DATA.TRAIN="('${UNLABELED_DATASET}',)" \
    EVAL.PSEUDO_INFERENCE=True
Train the STAC detector with the necessary configurations using the following command. Note that training may take hours, depending on the available hardware.
%%bash
python3 train_stg2.py \
    --logdir=${CKPT_PATH}/STAC --simple_path \
    --pseudo_path=${PSEUDO_PATH} \
    --config \
    BACKBONE.WEIGHTS=/content/coco/ImageNet-R50-AlignPadding.npz \
    DATA.BASEDIR=/content/coco/ \
    DATA.TRAIN="('${DATASET}',)" \
    DATA.UNLABEL="('${UNLABELED_DATASET}',)" \
    MODE_MASK=False \
    FRCNN.BATCH_PER_IM=64 \
    PREPROC.TRAIN_SHORT_EDGE_SIZE="[500,800]" \
    TRAIN.EVAL_PERIOD=20 \
    TRAIN.AUGTYPE_LAB='default' \
    TRAIN.AUGTYPE='strong' \
    TRAIN.CONFIDENCE=0.9 \
    TRAIN.WU=2
A trained model can be directly deployed for inference.
Performance of STAC
STAC is evaluated on the MS-COCO and VOC07 datasets against recent supervised baseline models. STAC outperforms supervised models trained either with strong augmentation or without augmentation. Trained with only 5% of the annotated data, STAC yields better performance than a supervised model trained with 10% of the annotated data.
The post Guide to Google’s STAC: An SSL Framework For Object Detection appeared first on Analytics India Magazine.