Source code for wandb_callbacks.tensorflow

import cv2
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

import wandb
from wandb_callbacks.utils import GradCAM
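
# The `GradCAM` helper imported above lives in `wandb_callbacks.utils` and is
# not reproduced on this page. The class below is only an illustrative sketch,
# inferred from how `GRADCamCallback` calls the helper further down
# (construction with `(model, layer_name)`, `compute_heatmap(image, class_idx)`
# and `overlay_heatmap(heatmap, image, alpha)`). It is NOT the library's
# actual implementation and is given a different name so that it does not
# shadow the real import.
class _GradCAMSketch:

    def __init__(self, model, layer_name):
        # Model mapping the input to the target conv layer's activations
        # and to the final predictions.
        self.grad_model = keras.models.Model(
            inputs=model.input,
            outputs=[model.get_layer(layer_name).output, model.output])

    def compute_heatmap(self, image, class_idx, eps=1e-8):
        with tf.GradientTape() as tape:
            conv_out, preds = self.grad_model(image)
            loss = preds[:, class_idx]
        # Gradient of the class score w.r.t. the conv feature maps.
        grads = tape.gradient(loss, conv_out)
        # Grad-CAM channel weights: global average of the gradients.
        weights = tf.reduce_mean(grads, axis=(0, 1, 2))
        cam = tf.reduce_sum(weights * conv_out[0], axis=-1)
        # ReLU, then normalise to [0, 255] for visualisation.
        cam = tf.maximum(cam, 0)
        cam = cam / (tf.reduce_max(cam) + eps)
        return (cam.numpy() * 255).astype("uint8")

    def overlay_heatmap(self, heatmap, image, alpha=0.5):
        # Colourise the heatmap and blend it onto the original image
        # (assumes a 3-channel uint8 image).
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        output = cv2.addWeighted(image, alpha, heatmap, 1 - alpha, 0)
        return heatmap, output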


class ActivationCallback(tf.keras.callbacks.Callback):
    """Logs the activations of a given layer to wandb as image grids."""

    def __init__(self, validation_data, layer_name, log_frequency=5):
        """Initializes the ActivationCallback.

        :param validation_data: Tuple of form (sample_images, sample_labels).
        :param layer_name: String name of the layer whose activations
            should be visualised.
        :param log_frequency: How often the activations should be logged
            (in epochs).
        """
        super(ActivationCallback, self).__init__()
        self.validation_data = validation_data
        self.layer_name = layer_name
        self.log_frequency = log_frequency
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        :param epoch: Integer, index of epoch.
        :param logs: Dict, metric results for this training epoch, and for
            the validation epoch if validation is performed. Validation
            result keys are prefixed with `val_`. For the training epoch,
            the values of the `Model`'s metrics are returned.
            Example: `{'loss': 0.2, 'acc': 0.7}`.
        """
        # Only log every `log_frequency` epochs
        if (epoch % self.log_frequency) != 0:
            return

        # Build an intermediate model that outputs the target layer's
        # activations
        self.intermediate_model = keras.models.Model(
            inputs=self.model.input,
            outputs=self.model.get_layer(self.layer_name).output)

        # Unpack validation data
        images, labels = self.validation_data

        for image, label in zip(images, labels):
            # Compute the output activations of the requested layer
            img = np.expand_dims(image, axis=0)
            features = self.intermediate_model.predict(img)
            features = features.reshape(features.shape[1:])
            # Move the channel axis to the front: (C, H, W)
            features = np.rollaxis(features, 2, 0)

            # Arrange the feature maps in a grid: the middle divisor of
            # the number of features gives the number of columns
            n_features = features.shape[0]
            divisors = self.__get_divisors(n_features)
            ncols = divisors[len(divisors) // 2]
            nrows = n_features // ncols
            # squeeze=False keeps `axs` 2-D even when nrows or ncols is 1
            fig, axs = plt.subplots(
                nrows=nrows, ncols=ncols, figsize=(15, 8), squeeze=False)

            # Plot each feature map without axis ticks
            c = 0
            for i in range(nrows):
                for j in range(ncols):
                    axs[i][j].imshow(features[c], cmap='gray')
                    axs[i][j].set_xticks([])
                    axs[i][j].set_yticks([])
                    c += 1

            # Log the figure object explicitly rather than the global
            # pyplot state
            wandb.log({"features_labels_{}".format(label): fig})
            plt.close(fig)
    @staticmethod
    def __get_divisors(n):
        """Gets the divisors of a given number, in ascending order.

        :param n: number to get the divisors of.
        """
        divisors = []
        for i in range(1, n // 2 + 1):
            if n % i == 0:
                divisors.append(i)
        divisors.append(n)
        return divisors
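
# Usage sketch (illustrative, not part of the library): the model, data, and
# layer / project names below are hypothetical placeholders.
def _example_activation_callback(model, x_train, y_train, x_val, y_val):
    """Log activations of a layer named `conv_1` every 2 epochs."""
    wandb.init(project="my-project")  # hypothetical project name
    activation_cb = ActivationCallback(
        validation_data=(x_val[:4], y_val[:4]),  # a few samples suffice
        layer_name="conv_1",              # assumes a layer named conv_1
        log_frequency=2)
    model.fit(x_train, y_train, epochs=10, callbacks=[activation_cb])
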
class DeadReluCallback(tf.keras.callbacks.Callback):
    """Reports the number of dead ReLUs after selected training epochs.

    A ReLU is considered dead if it did not fire once for the entire
    training set.
    """

    def __init__(self, x_train, log_frequency=1,
                 dead_threshold=0.8, verbose=False):
        """Initializes the DeadReluCallback.

        :param x_train: Training dataset used to check whether or not
            neurons fire.
        :param log_frequency: How often the dead ReLUs should be logged
            (in epochs).
        :param dead_threshold: If this share of dead neurons is exceeded
            in a layer, the callback prints a warning.
        :param verbose: Verbosity mode. `True` means that even a single
            dead neuron triggers a warning message. `False` means that
            only a significant share of dead neurons (at least
            `dead_threshold`) triggers a warning message.
        """
        super(DeadReluCallback, self).__init__()
        self.x_train = x_train
        self.log_frequency = log_frequency
        self.verbose = verbose
        self.dead_neurons_share_threshold = dead_threshold
    @staticmethod
    def is_relu_layer(layer):
        """Checks whether a given layer uses a ReLU activation.

        :param layer: layer object to check.
        """
        # Works for all layers with a `relu` activation in their config;
        # tested for Dense and Conv2D.
        return layer.get_config().get('activation', None) == 'relu'
    def get_relu_activations(self):
        """Retrieves all ReLU activations of the current model.

        Yields tuples of (layer_index, layer_activations, layer_name,
        layer_weight_shape) for every weighted ReLU layer.
        """
        model_input = self.model.input

        funcs = {}
        for index, layer in enumerate(self.model.layers):
            if not layer.get_weights():
                continue
            funcs[index] = keras.models.Model(
                inputs=model_input, outputs=layer.output)

        layer_outputs = {}
        for index, func in funcs.items():
            # Keep the batch axis so that `on_epoch_end` sums over the
            # whole training set, not just the first sample
            layer_outputs[index] = tf.dtypes.cast(
                func(self.x_train), tf.float64)

        for layer_index, layer_activations in layer_outputs.items():
            if self.is_relu_layer(self.model.layers[layer_index]):
                layer_name = self.model.layers[layer_index].name
                # With kernel and bias, the weights are saved as a list
                # [W, b]; with weights only, it is [W]
                layer_weight = self.model.layers[layer_index].get_weights()

                if not isinstance(layer_weight, list):
                    raise ValueError(
                        "'layer_weight' should be a list, "
                        "but was {}".format(type(layer_weight)))

                # There are no weights for the current layer; skip it.
                # This is only legitimate if the layer is an "Activation"
                # layer
                if len(layer_weight) == 0:
                    continue

                layer_weight_shape = np.shape(layer_weight[0])
                yield (layer_index, layer_activations,
                       layer_name, layer_weight_shape)
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        :param epoch: Integer, index of epoch.
        :param logs: Dict, metric results for this training epoch, and for
            the validation epoch if validation is performed. Validation
            result keys are prefixed with `val_`. For the training epoch,
            the values of the `Model`'s metrics are returned.
            Example: `{'loss': 0.2, 'acc': 0.7}`.
        """
        # Only log every `log_frequency` epochs
        if (epoch % self.log_frequency) != 0:
            return

        # Lists to compute the final average over all layers
        l_dead_neurons = []
        l_percentage_dead = []

        # Loop over all ReLU activations
        for relu_activation in self.get_relu_activations():
            (layer_index, activation_values,
             layer_name, layer_weight_shape) = relu_activation

            shape_act = activation_values.shape
            weight_len = len(layer_weight_shape)
            act_len = len(shape_act)

            # Should work for both Conv and Flat layers
            if K.image_data_format() == 'channels_last':
                # Features are in the last axis
                axis_filter = -1
            else:
                # Features come before the convolution axes; for
                # weight_len the input and output channels have to be
                # subtracted
                axis_filter = -1 - (weight_len - 2)

            total_featuremaps = shape_act[axis_filter]

            # Sum over every axis except the feature axis
            axis = []
            for i in range(act_len):
                if (i != axis_filter) and (
                        i != (len(shape_act) + axis_filter)):
                    axis.append(i)
            axis = tuple(axis)

            # A feature map is dead if its activations sum to zero over
            # the whole training set
            dead_neurons = np.sum(
                np.sum(activation_values, axis=axis) == 0.0, dtype='double')
            dead_neurons_share = \
                float(dead_neurons) / float(total_featuremaps)

            # Parentheses make the intended precedence explicit
            if (self.verbose and dead_neurons > 0) or \
                    dead_neurons_share >= self.dead_neurons_share_threshold:
                str_warning = (
                    'Layer {} (#{}) has {} '
                    'dead neurons ({:.2%})!').format(
                        layer_name, layer_index,
                        dead_neurons, dead_neurons_share)
                print(str_warning)

            # Log the per-layer numbers to wandb
            percentage_dead_neurons = round(dead_neurons_share * 100, 2)
            wandb.log({
                'n. of dead relus/Layer {} (#{})'.format(
                    layer_name, layer_index): dead_neurons,
                'percentage dead relus/Layer {} (#{})'.format(
                    layer_name, layer_index): percentage_dead_neurons})

            # Append to the overall lists
            l_dead_neurons.append(dead_neurons)
            l_percentage_dead.append(percentage_dead_neurons)

        # Log a summary over all layers (guard against models without
        # ReLU layers)
        if l_dead_neurons:
            l_dead_neurons = np.asarray(l_dead_neurons)
            l_percentage_dead = np.asarray(l_percentage_dead)
            wandb.log({
                'n. of dead relus/overall mean': l_dead_neurons.mean(),
                'percentage dead relus/overall mean':
                    l_percentage_dead.mean()})
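
# Usage sketch (illustrative, not part of the library): monitor dead ReLUs
# during training; the model and data are hypothetical placeholders.
def _example_dead_relu_callback(model, x_train, y_train):
    """Warn when at least 50% of a layer's ReLUs are dead."""
    dead_relu_cb = DeadReluCallback(
        x_train=x_train[:256],  # a subset keeps the extra forward pass cheap
        log_frequency=1,
        dead_threshold=0.5,
        verbose=False)
    model.fit(x_train, y_train, epochs=10, callbacks=[dead_relu_cb])
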
class GRADCamCallback(tf.keras.callbacks.Callback):
    """Logs Grad-CAM heatmaps for sample images to wandb."""

    def __init__(self, validation_data, layer_name, log_frequency=10):
        """Initializes the GRADCamCallback.

        :param validation_data: Tuple of form (sample_images, sample_labels).
        :param layer_name: String name of the layer that Grad-CAM should
            be computed against.
        :param log_frequency: How often the Grad-CAM images should be
            logged (in epochs). Choose with care, as this callback adds
            noticeable runtime overhead.
        """
        super(GRADCamCallback, self).__init__()
        self.validation_data = validation_data
        self.layer_name = layer_name
        self.log_frequency = log_frequency
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        :param epoch: Integer, index of epoch.
        :param logs: Dict, metric results for this training epoch, and for
            the validation epoch if validation is performed. Validation
            result keys are prefixed with `val_`. For the training epoch,
            the values of the `Model`'s metrics are returned.
            Example: `{'loss': 0.2, 'acc': 0.7}`.
        """
        # Only log every `log_frequency` epochs
        if (epoch % self.log_frequency) != 0:
            return

        # Lists for the original images and the Grad-CAM images that will
        # be passed to wandb
        images = []
        grad_cam = []

        # Initialize the GradCAM helper
        cam = GradCAM(self.model, self.layer_name)

        # Labels are not needed to compute Grad-CAM
        sample_images, _ = self.validation_data

        for image in sample_images:
            image = np.expand_dims(image, 0)
            pred = self.model.predict(image)
            class_idx = np.argmax(pred[0])

            # Compute the heatmap for the predicted class
            heatmap = cam.compute_heatmap(image, class_idx)

            image = image.reshape(image.shape[1:])
            image = (image * 255).astype(np.uint8)

            # Overlay the heatmap on the original image; cv2.resize
            # expects (width, height)
            heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
            (heatmap, output) = cam.overlay_heatmap(heatmap, image, alpha=0.5)

            images.append(image)
            grad_cam.append(output)
            grad_cam.append(heatmap)

        wandb.log({"images": [wandb.Image(img) for img in images]})
        wandb.log({"gradcam": [wandb.Image(img) for img in grad_cam]})
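
# Usage sketch (illustrative, not part of the library): Grad-CAM logging is
# comparatively expensive, so only a few images and a low frequency are used.
# `conv_last` is a hypothetical name; pick the last conv layer of your model.
def _example_gradcam_callback(model, x_train, y_train, x_val, y_val):
    """Log Grad-CAM overlays for four validation images every 10 epochs."""
    gradcam_cb = GRADCamCallback(
        validation_data=(x_val[:4], y_val[:4]),  # labels are not used
        layer_name="conv_last",
        log_frequency=10)
    model.fit(x_train, y_train, epochs=50, callbacks=[gradcam_cb])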