TensorFlow Object Detection API 源码分析之 utils/label_map_util.py
protos/string_int_label_map.proto
syntax = "proto2";

package object_detection.protos;

// One entry of a label map: associates a string name with an integer id.
message StringIntLabelMapItem {
  // String name. The most common practice is to set this to a MID or synsets
  // id.
  optional string name = 1;

  // Integer id that maps to the string name above. Label ids should start from
  // 1. (The Python loader in utils/label_map_util.py reserves id 0 for the
  // 'background' label and rejects negative ids.)
  optional int32 id = 2;

  // Human readable string label.
  optional string display_name = 3;
};

// A label map: an ordered list of StringIntLabelMapItem entries.
message StringIntLabelMap {
  repeated StringIntLabelMapItem item = 1;
};
utils/label_map_util.py
# 有关 label map 的辅助函数
"""Label map utility functions."""

import logging

import tensorflow as tf
from google.protobuf import text_format

from object_detection.protos import string_int_label_map_pb2


# Background uses id 0 by default, so no other label may use id 0;
# otherwise an error is raised.
def _validate_label_map(label_map):"""Checks if a label map is valid.Args:label_map: StringIntLabelMap to validate.Raises:ValueError: if label map is invalid."""for item in label_map.item:if item.id < 0:raise ValueError('Label map ids should be >= 0.')if (item.id == 0 and item.name != 'background' anditem.display_name != 'background'):raise ValueError('Label map id 0 is reserved for the background label')# 返回 类别 和 index 的字典 {key为index,value为类别id}
# e.g. category_index[1] == {'id': 1, 'name': 'dog'}
def create_category_index(categories):
    """Index COCO-compatible category dicts by their 'id' field.

    Args:
      categories: a list of dicts, each with the following keys:
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name,
          e.g., 'cat', 'dog', 'pizza'.

    Returns:
      category_index: a dict holding the same category dicts (not copies),
        keyed by each one's 'id'. If several categories share an id, the last
        one in `categories` wins.
    """
    return {category['id']: category for category in categories}


# Returns the largest id in a label map.
def get_max_label_map_index(label_map):
    """Return the largest id among the items of a label map.

    Args:
      label_map: a StringIntLabelMapProto.

    Returns:
      an integer, the maximum `id` over `label_map.item`.

    Raises:
      ValueError: raised by `max` when the label map has no items.
    """
    ids = (entry.id for entry in label_map.item)
    return max(ids)


# Converts a label map into a list of categories
# [{'id': item.id, 'name': name}, ...].
# 可选 name 或者 display_name
def convert_label_map_to_categories(label_map,
                                    max_num_classes,
                                    use_display_name=True):
    """Loads label map proto and returns categories list compatible with eval.

    This function loads a label map and returns a list of dicts, each of which
    has the following keys:
      'id': (required) an integer id uniquely identifying this category.
      'name': (required) string representing category name,
        e.g., 'cat', 'dog', 'pizza'.
    We only allow class into the list if its id-label_id_offset is
    between 0 (inclusive) and max_num_classes (exclusive).
    If there are several items mapping to the same id in the label map,
    we will only keep the first one in the categories list.

    Args:
      label_map: a StringIntLabelMapProto or None. If None, a default categories
        list is created with max_num_classes categories.
      max_num_classes: maximum number of (consecutive) label indices to include.
      use_display_name: (boolean) choose whether to load 'display_name' field
        as category name. If False or if the display_name field does not exist,
        uses 'name' field as category names instead.

    Returns:
      categories: a list of dictionaries representing all possible categories.
    """
    categories = []
    if not label_map:
        # No label map: synthesize ids 1..max_num_classes (id 0 is background).
        label_id_offset = 1
        for class_id in range(max_num_classes):
            categories.append({
                'id': class_id + label_id_offset,
                'name': 'category_{}'.format(class_id + label_id_offset)
            })
        return categories

    # A set gives O(1) duplicate checks; the previous list made this loop
    # quadratic in the number of label map items.
    ids_already_added = set()
    for item in label_map.item:
        if not 0 < item.id <= max_num_classes:
            logging.info('Ignore item %d since it falls outside of requested '
                         'label range.', item.id)
            continue
        if use_display_name and item.HasField('display_name'):
            name = item.display_name
        else:
            name = item.name
        # Keep only the first item seen for each id.
        if item.id not in ids_already_added:
            ids_already_added.add(item.id)
            categories.append({'id': item.id, 'name': name})
    return categories


# Loads a label map from a file; used by the next function.
def load_labelmap(path):
    """Loads label map proto.

    The file may hold either the text-format or the binary serialization of a
    StringIntLabelMap; text format is tried first, binary second.

    Args:
      path: path to StringIntLabelMap proto text (or binary) file.

    Returns:
      a StringIntLabelMapProto

    Raises:
      ValueError: if the parsed label map fails `_validate_label_map`.
      The protobuf decode error propagates if the file is neither valid text
      format nor valid binary.
    """
    # Read raw bytes: `ParseFromString` requires bytes, so reading in text mode
    # made the binary fallback fail with TypeError on Python 3.
    with tf.gfile.GFile(path, 'rb') as fid:
        label_map_bytes = fid.read()
    label_map = string_int_label_map_pb2.StringIntLabelMap()
    try:
        text_format.Merge(label_map_bytes.decode('utf-8'), label_map)
    except (text_format.ParseError, UnicodeDecodeError):
        # Not text format (binary protos are often not valid UTF-8 either);
        # fall back to the wire format. ParseFromString clears the message
        # first, discarding anything a partial text parse left behind.
        label_map.ParseFromString(label_map_bytes)
    _validate_label_map(label_map)
    return label_map


# Builds a name -> id dict from a label map file,
# e.g. label_map_dict['background'] = 0.
def get_label_map_dict(label_map_path,
                       use_display_name=False,
                       fill_in_gaps_and_background=False):
    """Reads a label map and returns a dictionary of label names to id.

    Args:
      label_map_path: path to StringIntLabelMap proto text file.
      use_display_name: whether to use the label map items' display names as
        keys.
      fill_in_gaps_and_background: whether to fill in gaps and background with
        respect to the id field in the proto. The id: 0 is reserved for the
        'background' class and will be added if it is missing. All other
        missing ids in range(1, max(id)) will be added with a dummy class name
        ("class_<id>") if they are missing. An empty label map yields
        {'background': 0}.

    Returns:
      A dictionary mapping label names to id.

    Raises:
      ValueError: if fill_in_gaps_and_background and label_map has non-integer
        or negative values.
    """
    label_map = load_labelmap(label_map_path)
    label_map_dict = {}
    for item in label_map.item:
        if use_display_name:
            label_map_dict[item.display_name] = item.id
        else:
            label_map_dict[item.name] = item.id

    if fill_in_gaps_and_background:
        values = set(label_map_dict.values())

        if 0 not in values:
            label_map_dict['background'] = 0
        if not all(isinstance(value, int) for value in values):
            raise ValueError('The values in label map must be integers in order to '
                             'fill_in_gaps_and_background.')
        if not all(value >= 0 for value in values):
            raise ValueError('The values in the label map must be non-negative.')

        # `values` may be empty (empty label map); max() would then raise.
        if values and len(values) != max(values) + 1:
            # There are gaps in the labels; fill them with dummy class names.
            for value in range(1, max(values)):
                if value not in values:
                    label_map_dict['class_' + str(value)] = value

    return label_map_dict


# Builds a dict from category id to category entry.
# {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
# model_lib.py 中调用的函数,重要
def create_category_index_from_labelmap(label_map_path, use_display_name=True):
    """Reads a label map and returns a category index.

    Args:
      label_map_path: Path to `StringIntLabelMap` proto text file.
      use_display_name: (boolean) choose whether to load 'display_name' field
        as category name. If False or if the display_name field does not exist,
        uses 'name' field as category names instead. Defaults to True, which
        matches the previous behavior.

    Returns:
      A category index, which is a dictionary that maps integer ids to dicts
      containing categories, e.g.
      {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
    """
    label_map = load_labelmap(label_map_path)
    # Reuse the shared helper instead of re-deriving the maximum id inline.
    max_num_classes = get_max_label_map_index(label_map)
    categories = convert_label_map_to_categories(
        label_map, max_num_classes, use_display_name=use_display_name)
    return create_category_index(categories)


# Builds a class-agnostic category index with a single 'object' category.
def create_class_agnostic_category_index():
    """Build a category index holding one generic `object` category.

    Returns:
      A new dict {1: {'id': 1, 'name': 'object'}}, built fresh on every call.
    """
    object_category = {'id': 1, 'name': 'object'}
    return {object_category['id']: object_category}
转自:
TensorFlow Object Detection API 源码分析之 utils/label_map_util.py:这个人的博客里还有关于其它目标检测API的代码