
Retraining Faster R-CNN from the TF2 Object Detection API on a custom dataset #11251

Open
omzoughi opened this issue Aug 15, 2024 · 0 comments
Labels
models:research models that come under research directory type:feature

Comments

@omzoughi

Dear all,

I am trying to retrain the classification head of Faster R-CNN while keeping the other branches, so I wrote the following code:

import tensorflow as tf
from object_detection.builders import model_builder
from object_detection.utils import config_util

# Clear the current TensorFlow session
tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)

# Settings
num_classes = 2
pipeline_config = '/MyPath/models/research/object_detection/configs/tf2/faster_rcnn_resnet50_v1_640x640_coco17_tpu-8.config'
checkpoint_path = '/MyPath/models/research/object_detection/test_data2/checkpoint/faster_rcnn_resnet50_v1_640x640_coco17_tpu-8/checkpoint/ckpt-0'

# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']

# Update the model config to reflect the number of classes
model_config.faster_rcnn.num_classes = num_classes
model_config.faster_rcnn.feature_extractor.batch_norm_trainable = False

# Build the detection model in training mode, since it is going to be fine-tuned
detection_model = model_builder.build(model_config=model_config, is_training=True)

# Restore all matching weights from the checkpoint; expect_partial() only
# silences warnings about checkpoint values that are never used
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(checkpoint_path).expect_partial()

# Reinitialize the class prediction head to match the new number of classes
class_prediction_head = detection_model._mask_rcnn_box_predictor._class_prediction_head

# Access the Dense layer responsible for class prediction
dense_layer = class_prediction_head._class_predictor_layers[1]

num_units = num_classes + 1  # Add 1 for the background class

# Create a new Dense layer with the correct number of classes
new_dense_layer = tf.keras.layers.Dense(
    units=num_units,
    activation=dense_layer.activation,
    use_bias=dense_layer.use_bias,
    kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),  # Initialize new layer weights randomly
    bias_initializer=dense_layer.bias_initializer,
    kernel_regularizer=dense_layer.kernel_regularizer,
    bias_regularizer=dense_layer.bias_regularizer,
    activity_regularizer=dense_layer.activity_regularizer,
    kernel_constraint=dense_layer.kernel_constraint,
    bias_constraint=dense_layer.bias_constraint,
    name=dense_layer.name  # Keep the same name
)

# Replace the old Dense layer with the new one in the class prediction head
class_prediction_head._class_predictor_layers[1] = new_dense_layer

# In training mode, Faster R-CNN samples second-stage proposals against the
# groundtruth inside predict(), so provide one dummy box before the dummy pass.
# The real training step must call provide_groundtruth() with the actual labels.
detection_model.provide_groundtruth(
    groundtruth_boxes_list=[tf.constant([[0.1, 0.1, 0.5, 0.5]], dtype=tf.float32)],
    groundtruth_classes_list=[tf.one_hot([0], depth=num_classes)])

# Run the model through a dummy image to ensure all variables are properly initialized
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)

print('Weights restored for box predictor! Classification head initialized with the new number of classes.')
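To update only the classification head and keep the other branches fixed, a custom training step would pass only that head's variables to `tape.gradient` and `optimizer.apply_gradients`. The sketch below shows just the selection logic in plain Python, assuming the head's variables can be told apart by a substring of their names. The variable names and the `"ClassPredictor"` key are hypothetical placeholders, not names confirmed for this checkpoint; the real ones can be listed with `[v.name for v in detection_model.trainable_variables]`.

```python
from typing import List, Sequence, Tuple


def split_by_name(
    names: Sequence[str], train_keys: Sequence[str]
) -> Tuple[List[str], List[str]]:
    """Partition variable names into (to_train, to_freeze).

    A name goes to to_train if it contains any substring in train_keys;
    everything else is treated as frozen.
    """
    to_train: List[str] = []
    to_freeze: List[str] = []
    for name in names:
        if any(key in name for key in train_keys):
            to_train.append(name)
        else:
            to_freeze.append(name)
    return to_train, to_freeze


# Hypothetical variable names for illustration only; list the real ones with
# [v.name for v in detection_model.trainable_variables]
names = [
    "FirstStageFeatureExtractor/conv1/kernel:0",
    "FirstStageBoxPredictor/BoxEncodingPredictor/kernel:0",
    "SecondStageBoxPredictor/ClassPredictor/kernel:0",
    "SecondStageBoxPredictor/ClassPredictor/bias:0",
    "SecondStageBoxPredictor/BoxEncodingPredictor/kernel:0",
]

# "ClassPredictor" is an assumed key for the classification head's variables
train, frozen = split_by_name(names, ["ClassPredictor"])
print("train:", train)
print("frozen:", frozen)
```

With TensorFlow, the same filter would be applied to the variable objects themselves, e.g. `[v for v in detection_model.trainable_variables if 'ClassPredictor' in v.name]`, and only that list would be handed to the optimizer, so the frozen branches receive no updates.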

During training, my loss stays around 1. I chose the following optimizer:

learning_rate = 0.000001
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
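The Keras `SGD` documentation gives the momentum update as `velocity = momentum * velocity - learning_rate * g` followed by `w = w + velocity`. A small pure-Python sketch of that rule, assuming a constant gradient for simplicity, shows the scale of the per-step change with these settings: the first step is `learning_rate * g`, and the step grows toward `learning_rate * g / (1 - momentum)`, i.e. about 1e-5 per unit of gradient for `learning_rate = 0.000001` and `momentum = 0.9`. It illustrates step size only.

```python
from typing import List


def momentum_sgd_steps(grad: float, lr: float, momentum: float, steps: int) -> List[float]:
    """Return the change applied to a weight at each step of SGD with
    (non-Nesterov) momentum, for a constant gradient `grad`.

    Update rule:
        velocity = momentum * velocity - lr * grad
        w = w + velocity
    """
    velocity = 0.0
    deltas: List[float] = []
    for _ in range(steps):
        velocity = momentum * velocity - lr * grad
        deltas.append(velocity)  # amount added to w at this step
    return deltas


lr, momentum = 1e-6, 0.9
deltas = momentum_sgd_steps(grad=1.0, lr=lr, momentum=momentum, steps=200)
print("first step:", deltas[0])           # equals -lr * grad
print("step after 200 iters:", deltas[-1])
print("limit -lr / (1 - momentum):", -lr / (1 - momentum))
```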

Is the checkpoint restoration correct? Can anyone help me spot any inconsistencies in the model definition?

Thank you

@omzoughi omzoughi added models:research models that come under research directory type:feature labels Aug 15, 2024