import torch
import os
from collections import OrderedDict


def model_to_cpu(model_state):
    r"""Move every tensor in the state dict to CPU memory so that the saved
    checkpoint can later be loaded without occupying GPU memory.

    :meta private:
    """
    new_state = OrderedDict()
    for k, v in model_state.items():
        new_state[k] = v.cpu()
    return new_state


def save_checkpoints(model, output_dir, checkpoint_prefix, epochs):
    r"""Save the model parameter to a pth file.
    Args:
        model: the ssds model
        output_dir (str): the folder for model saving, usually defined by cfg.EXP_DIR
        checkpoint_prefix (str): the prefix for the checkpoint, usually is the combination of the ssd model and the dataset
        epochs (int): the epoch for the current training
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    filename = checkpoint_prefix + "_epoch_{:d}".format(epochs) + ".pth"
    filename = os.path.join(output_dir, filename)
    torch.save(model_to_cpu(model.state_dict()), filename)
    with open(os.path.join(output_dir, "checkpoint_list.txt"), "a") as f:
        f.write("epoch {epoch:d}: {filename}\n".format(epoch=epochs, filename=filename))
    print("Wrote snapshot to: {:s}".format(filename)) 


def find_previous_checkpoint(output_dir):
    r"""Return the most recent checkpoint in the checkpoint_list.txt
    
    checkpoint_list.txt is usually saved at cfg.EXP_DIR
    Args:
        output_dir (str): the folder contains the previous checkpoints and checkpoint_list.txt
    """
    if not os.path.exists(os.path.join(output_dir, "checkpoint_list.txt")):
        return False
    with open(os.path.join(output_dir, "checkpoint_list.txt"), "r") as f:
        lineList = f.readlines()
    # each line has the form "epoch {N}: {checkpoint path}"
    epochs, resume_checkpoints = [], []
    for line in lineList:
        epoch = int(line[line.find("epoch ") + len("epoch ") : line.find(":")])
        checkpoint = line[line.find(":") + 2 : -1]
        epochs.append(epoch)
        resume_checkpoints.append(checkpoint)
    return epochs, resume_checkpoints
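

# Example usage for find_previous_checkpoint (illustrative sketch; `net` and the
# experiment directory are hypothetical):
#
#     previous = find_previous_checkpoint("./experiments/ssd_voc")
#     if previous:
#         epochs, checkpoints = previous
#         start_epoch = epochs[-1]                       # most recently recorded epoch
#         net = resume_checkpoint(net, checkpoints[-1])  # resume from its checkpoint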


def resume_checkpoint(model, resume_checkpoint, resume_scope=""):
    r"""Resume the checkpoints to the given ssds model based on the resume_scope.
    The resume_scope is defined by cfg.TRAIN.RESUME_SCOPE.
    When:
    * cfg.TRAIN.RESUME_SCOPE = ""
        All the parameters in the resume_checkpoint are resumed to the model
    * cfg.TRAIN.RESUME_SCOPE = "a,b,c"
        Only the the parameters in the a, b and c are resumed to the model
    
    Args:
        model: the ssds model
        resume_checkpoint (str): the file address for the checkpoint which contains the resumed parameters
        resume_scope: the scope of the resumed parameters, defined at cfg.TRAIN.RESUME_SCOPE
    """
    if resume_checkpoint == "" or not os.path.isfile(resume_checkpoint):
        print(("=> no checkpoint found at '{}'".format(resume_checkpoint)))
        return False
    print(("=> loading checkpoint '{:s}'".format(resume_checkpoint)))
    checkpoint = torch.load(resume_checkpoint, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    # print("=> Weigths in the checkpoints:")
    # print([k for k, v in list(checkpoint.items())])
    # remove the module in the parrallel model
    if "module." in list(checkpoint.items())[0][0]:
        pretrained_dict = {
            ".".join(k.split(".")[1:]): v for k, v in list(checkpoint.items())
        }
        checkpoint = pretrained_dict
    # change the name of the weights which exists in other model
    # change_dict = {
    # }
    # for k, v in list(checkpoint.items()):
    #     for _k, _v in list(change_dict.items()):
    #         if _k in k:
    #             new_key = k.replace(_k, _v)
    #             checkpoint[new_key] = checkpoint.pop(k)
    # remove the output layers from the checkpoint
    # remove_list = {
    # }
    # for k in remove_list:
    #     checkpoint.pop(k+'.weight', None)
    #     checkpoint.pop(k+'.bias', None)
    # extract the weights based on the resume scope
    if resume_scope != "":
        pretrained_dict = {}
        for k, v in list(checkpoint.items()):
            for resume_key in resume_scope.split(","):
                if resume_key in k:
                    pretrained_dict[k] = v
                    break
        checkpoint = pretrained_dict
    # keep only the weights whose names also exist in the current model
    pretrained_dict = {k: v for k, v in checkpoint.items() if k in model.state_dict()}
    # print("=> Resumed weights:")
    # print([k for k, v in list(pretrained_dict.items())])
    # report the model parameters that are not covered by the checkpoint
    checkpoint = model.state_dict()
    unresume_dict = set(checkpoint) - set(pretrained_dict)
    if len(unresume_dict) != 0:
        print("=> Unresumed weights:")
        print(unresume_dict)
    checkpoint.update(pretrained_dict)
    model.load_state_dict(checkpoint)
    return model
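

# Example usage for resume_checkpoint (illustrative sketch; `net`, the checkpoint
# path and the scope names are hypothetical):
#
#     # resume every matching parameter from the checkpoint
#     net = resume_checkpoint(net, "./experiments/ssd_voc/ssd_vgg16_voc_epoch_10.pth")
#
#     # resume only the parameters whose names contain "backbone" or "extras",
#     # e.g. when cfg.TRAIN.RESUME_SCOPE = "backbone,extras"
#     net = resume_checkpoint(
#         net, "./experiments/ssd_voc/ssd_vgg16_voc_epoch_10.pth", "backbone,extras"
#     )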