Source code for ssds.core.checkpoint

import os
from collections import OrderedDict

import torch


def model_to_cpu(model_state):
    r""" make sure the model is load from cpu memory. In this case, the loaded model will not occupied the gpu memory.

    :meta private:
    """
    new_state = OrderedDict()
    for k, v in model_state.items():
        new_state[k] = v.cpu()
    return new_state
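
# A quick sketch of what model_to_cpu does (illustrative tensors; the .cuda()
# call assumes a CUDA device is available; this helper is internal to the module):
#
#     state = OrderedDict(w=torch.ones(2, 2).cuda())
#     cpu_state = model_to_cpu(state)
#     assert cpu_state["w"].device.type == "cpu"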


def save_checkpoints(model, output_dir, checkpoint_prefix, epochs):
    r"""Save the model parameters to a .pth file.

    Args:
        model: the ssds model
        output_dir (str): the folder for model saving, usually defined by cfg.EXP_DIR
        checkpoint_prefix (str): the prefix for the checkpoint, usually a combination of the ssd model name and the dataset name
        epochs (int): the epoch of the current training
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    filename = checkpoint_prefix + "_epoch_{:d}".format(epochs) + ".pth"
    filename = os.path.join(output_dir, filename)
    torch.save(model_to_cpu(model.state_dict()), filename)
    with open(os.path.join(output_dir, "checkpoint_list.txt"), "a") as f:
        f.write(
            "epoch {epoch:d}: {filename}\n".format(epoch=epochs, filename=filename)
        )
    print("Wrote snapshot to: {:s}".format(filename))
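
# A minimal usage sketch for save_checkpoints. `net`, "./experiments" and the
# "ssd_voc" prefix are hypothetical stand-ins for the real ssds model and the
# cfg.EXP_DIR / prefix values:
#
#     net = torch.nn.Linear(4, 2)  # any torch.nn.Module
#     save_checkpoints(net, "./experiments", "ssd_voc", epochs=10)
#     # writes ./experiments/ssd_voc_epoch_10.pth and appends
#     # "epoch 10: ./experiments/ssd_voc_epoch_10.pth" to checkpoint_list.txt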

def find_previous_checkpoint(output_dir):
    r"""Return the epochs and checkpoint paths recorded in checkpoint_list.txt, or False if no such file exists.

    checkpoint_list.txt is usually saved at cfg.EXP_DIR.

    Args:
        output_dir (str): the folder containing the previous checkpoints and checkpoint_list.txt
    """
    if not os.path.exists(os.path.join(output_dir, "checkpoint_list.txt")):
        return False
    with open(os.path.join(output_dir, "checkpoint_list.txt"), "r") as f:
        lineList = f.readlines()
    epoches, resume_checkpoints = [list() for _ in range(2)]
    for line in lineList:
        epoch = int(line[line.find("epoch ") + len("epoch ") : line.find(":")])
        checkpoint = line[line.find(":") + 2 : -1]
        epoches.append(epoch)
        resume_checkpoints.append(checkpoint)
    return epoches, resume_checkpoints
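
# A usage sketch for find_previous_checkpoint, assuming a checkpoint_list.txt
# previously written by save_checkpoints (the directory name is hypothetical):
#
#     previous = find_previous_checkpoint("./experiments")
#     if previous:
#         epochs, checkpoints = previous
#         # the last entry is the most recent snapshot
#         print("latest epoch:", epochs[-1], "->", checkpoints[-1])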

def resume_checkpoint(model, resume_checkpoint, resume_scope=""):
    r"""Resume the checkpoint parameters to the given ssds model, based on the resume_scope.

    The resume_scope is defined by cfg.TRAIN.RESUME_SCOPE. When:

    * cfg.TRAIN.RESUME_SCOPE = ""
      All the parameters in the resume_checkpoint are resumed to the model
    * cfg.TRAIN.RESUME_SCOPE = "a,b,c"
      Only the parameters in scopes a, b and c are resumed to the model

    Args:
        model: the ssds model
        resume_checkpoint (str): the file path of the checkpoint that contains the parameters to resume
        resume_scope: the scope of the resumed parameters, defined by cfg.TRAIN.RESUME_SCOPE
    """
    if resume_checkpoint == "" or not os.path.isfile(resume_checkpoint):
        print("=> no checkpoint found at '{}'".format(resume_checkpoint))
        return False
    print("=> loading checkpoint '{:s}'".format(resume_checkpoint))
    checkpoint = torch.load(resume_checkpoint, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # remove the "module." prefix added by the parallel model wrapper
    if "module." in list(checkpoint.items())[0][0]:
        checkpoint = {
            ".".join(k.split(".")[1:]): v for k, v in list(checkpoint.items())
        }

    # optionally rename weights that come from another model
    # change_dict = {}
    # for k, v in list(checkpoint.items()):
    #     for _k, _v in list(change_dict.items()):
    #         if _k in k:
    #             new_key = k.replace(_k, _v)
    #             checkpoint[new_key] = checkpoint.pop(k)

    # optionally remove the output layers from the checkpoint
    # remove_list = {}
    # for k in remove_list:
    #     checkpoint.pop(k + ".weight", None)
    #     checkpoint.pop(k + ".bias", None)

    # extract the weights based on the resume scope
    if resume_scope != "":
        pretrained_dict = {}
        for k, v in list(checkpoint.items()):
            for resume_key in resume_scope.split(","):
                if resume_key in k:
                    pretrained_dict[k] = v
                    break
        checkpoint = pretrained_dict

    # keep only the weights whose names exist in the current model
    pretrained_dict = {k: v for k, v in checkpoint.items() if k in model.state_dict()}

    checkpoint = model.state_dict()
    unresume_dict = set(checkpoint) - set(pretrained_dict)
    if len(unresume_dict) != 0:
        print("=> weights not resumed:")
        print(unresume_dict)

    checkpoint.update(pretrained_dict)
    model.load_state_dict(checkpoint)
    return model
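
# A sketch of how these functions are typically chained in a training loop
# (`net` and the directory/prefix names are illustrative, not part of this
# module):
#
#     net = torch.nn.Linear(4, 2)
#     previous = find_previous_checkpoint("./experiments")
#     if previous:
#         _, checkpoints = previous
#         net = resume_checkpoint(net, checkpoints[-1], resume_scope="")
#     for epoch in range(1, 3):
#         # ... run one epoch of training here ...
#         save_checkpoints(net, "./experiments", "ssd_voc", epochs=epoch)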