def adjust_input(in_data):  # adjust the input layout for the network
    """
    Adjust the input from (h, w, c) to (1, c, h, w) for network input.

    Parameters:
    ----------
        in_data: numpy array (or array-like) of shape (h, w, c)
            input data
    Returns:
    -------
        out_data: float32 numpy array of shape (1, c, h, w)
            transposed array with a leading axis of length 1, normalised as
            ``(x - 127.5) * 0.0078125``
    """
    out_data = np.asanyarray(in_data)
    # Compare dtypes by equality: identity (`is`) only works as long as
    # NumPy happens to hand back the very same dtype object.
    if out_data.dtype != np.float32:
        out_data = out_data.astype(np.float32)

    # (h, w, c) -> (c, h, w), then prepend an axis of length 1.
    out_data = out_data.transpose((2, 0, 1))
    out_data = np.expand_dims(out_data, 0)

    # The arithmetic allocates a new array, so the caller's data is
    # never modified even when no dtype conversion was needed.
    out_data = (out_data - 127.5) * 0.0078125
    return out_data
def generate_bbox(map, reg, scale, threshold, stride=2, cellsize=12):  # generate bbox
    """
    Generate bbox candidates from a score map.

    Parameters:
    ----------
        map: numpy array, n x m
            detect score for each position
        reg: numpy array, indexed as reg[0, channel, row, col]
            bbox regression values; the first four channels are used
        scale: float number
            scale of this detection
        threshold: float number
            detect threshold; only positions with score > threshold are kept
        stride: int, optional (default 2)
            step, in scaled-image pixels, between neighbouring map cells
        cellsize: int, optional (default 12)
            side length, in scaled-image pixels, of the window behind a cell
    Returns:
    -------
        bbox array of shape (k, 9), one row per kept position:
        [x1, y1, x2, y2, score, dx1, dy1, dx2, dy2].
        An empty array (``np.array([])``) when nothing exceeds the threshold.
    """
    hits = np.where(map > threshold)
    rows, cols = hits[0], hits[1]

    # find nothing
    if rows.size == 0:
        return np.array([])

    # Gather the four regression channels at every kept position -> (4, k).
    offsets = reg[0][:4][:, rows, cols]
    score = map[rows, cols]

    # Map cell indices back to coordinates in the unscaled image.
    boundingbox = np.vstack([
        np.round((stride * cols + 1) / scale),
        np.round((stride * rows + 1) / scale),
        np.round((stride * cols + 1 + cellsize) / scale),
        np.round((stride * rows + 1 + cellsize) / scale),
        score,
        offsets,
    ])

    return boundingbox.T
def detect_first_stage(img, net, scale, threshold):  # first detection stage
    """
    Run PNet for the first stage.

    Parameters:
    ----------
        img: numpy array, bgr order
            input image
        net: PNet worker
            object exposing ``predict``
        scale: float number
            how much should the input image scale
        threshold: float number
            detect threshold forwarded to ``generate_bbox``
    Returns:
    -------
        total_boxes: bboxes kept after NMS, or None when no position
        exceeds the threshold
    """
    src_h, src_w, _ = img.shape
    scaled_h = int(math.ceil(src_h * scale))
    scaled_w = int(math.ceil(src_w * scale))

    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(img, (scaled_w, scaled_h))

    # adjust for the network input
    prediction = net.predict(adjust_input(resized))

    # prediction[1][0, 1] is used as the score map, prediction[0] as the
    # regression values.
    score_map = prediction[1][0, 1, :, :]
    candidates = generate_bbox(score_map, prediction[0], scale, threshold)
    if candidates.size == 0:
        return None

    # nms
    keep = nms(candidates[:, 0:5], 0.5, mode='Union')
    return candidates[keep]
def detect_first_stage_warpper(args):
    """
    Call ``detect_first_stage`` with a packed argument sequence.

    Parameters:
    ----------
        args: iterable
            positional arguments, unpacked in order into
            ``detect_first_stage``
    Returns:
    -------
        whatever ``detect_first_stage`` returns for those arguments
    """
    result = detect_first_stage(*args)
    return result