中值滤波及快速中值滤波的 Python 实现

本文最后更新于:2020年8月16日 下午

写完发现比 opencv 自带的慢了十倍。😂

前言

原理倒是不难,但是对于优化而言就麻烦了。最开始看他们说利用直方图的时候我还没看懂,看了别人的源码才明白原来说的直方图就是桶排序嘛。

但是在测试性能的时候发现比 opencv 自带的函数慢了十倍,还是用了优化过的算法。不知道是哪里出了问题。

环境

  • python 3.8.1
  • opencv-python 4.2.0.34

正文

中值滤波的原理

简单地说就是对于像素点 p[x][y], 它的颜色由它及周围的像素的颜色的组成的序列的中位数来确定。这个周围可以是十字形的,可以是以他为中心的3*3 矩阵或者5*5,n*n(n 最好为奇数,这样直接取排序后中间的那个值就可以了,不然还得求平均数) 的矩阵。而这个的时间复杂度明显就至少是 O(M*N*P^2) 的,其中 M,N 分别为原图像的长宽,而 P 则为你用于取色的那个窗口的大小。这性能开销明显是不可接受的,那么怎么优化呢?

快速中值滤波的原理

通过上面的分析我们可以知道,这里面最大的开销就是排序的这个过程了。而仔细观察的话就能发现,排序的这个过程会有相当大程度的重复。

就拿 3*3 的矩阵来举例。每次向右移动一个像素的话,最左边的 p[x-1]这一列的像素是没有用的,而中间和右边的 p[x],p[x+1]这两列在接下来的排序还是可以继续使用的。那么应该怎么去实现呢?

初始化

这里我们可以通过计数排序来实现。
我们先建立一个有 256 个位数的数组Histogram,这里面每一位表示这个颜色的像素有几个。没有就记为 0。

这里选取 256 的原因是因为我是对单一通道进行处理的,每一个颜色的范围是[0,255]。如果用其他色彩空间的话可以根据需求更改。

现在我们开始对这个点进行维护,对于每一列的最下面的那个点,直接将周围的像素存进 Histogram 就可以了。而要写入的新的颜色值,就是根据这个 Histogram 得到的中位数。当然了,也可以直接通过 numpy 自带的求中位数的函数 numpy.median() 来求得。

计数排序求中位数

我们每次都要确定中位数,而在这种计数排序中,为了方便确定,我们引入一个 n 用来记住小于等于中位数的数字的个数。

每次在更新了数组 Histogram 之后,要及时的修改 n 的值。这样方便接下来的操作。

现在在更新完了数组 Histogram 之后,我们应该及时地更新中位数med。我们可以像下面这样做👇

if n > 5 :
    while n > 5 :
        if med == 0 :
            break
        n = n - int(Histogram[med])
        med = med - 1
elif n < 5 :
    while n < 5 :
        med = med + 1
        n = n + int(Histogram[med])

在这里我们判断,比起更新前的中位数 med1,为了便于区别这里用med1 代替,现在的中位数 med 是大了还是小了?

这就是我们之前记录 n 的目的了。当 n > 5 的时候我们可以发现现在要求的中位数 med 应该小于之前的中位数 med1,即med < med1, 这个时候就应该让med 逐渐减小,直到第一次 n <= 5 为止,这个时候的 med 就是我们要求的 med 了。

当然还有特殊情况就是当 med = 0 的时候。这个时候是不会再小了,那么直接退出就可以了。

同样的逻辑可以解决 n < 5 的情况。而这里没有特殊情况的原因是因为我们之前定义 n 的时候,是把它记为 记住小于等于中位数的数字的个数
如果 med1 = 255 那么可以知道,n 此时是必然等于 9 的。因为不存在比 255 更大的颜色了。

源码实现

要对比性能的话可以修改 medianBlur() 函数中的

# b = medianBlurChannel(b)
# g = medianBlurChannel(g)
# r = medianBlurChannel(r)
b = fastMedianBlurChannel(b)
g = fastMedianBlurChannel(g)
r = fastMedianBlurChannel(r)

这一段,上面的是中值滤波,下面的是快速中值滤波。
下面是源码👇

import cv2 
import numpy
import copy
import datetime
# 对于单一通道进行中值滤波
def medianBlurChannel(channel):
    channel2 = copy.copy(channel)
    for x in range(1,len(channel)-1):
        for y in range(1,len(channel[0])-1):
            channel2[x][y]=numpy.median(channel[x-1:x+2,y-1:y+2])
    return channel2

# 快速中值滤波
def fastMedianBlurChannel(channel):
    channel2 = copy.copy(channel)
    for x in range(1,len(channel)-1):
        # 初始化直方图
        Histogram = numpy.zeros(256,dtype=int)
        # 对每一列的第一个像素初始化
        # 它的颜色取 channel[x][1]它及周围的八个像素的颜色的中位数
        med = int(numpy.median(channel[x-1:x+2,0:3]))
        # 用 n 确定中值的偏移量
        n = 0
        for i in range(-1,2):
            for j in range(0,3):
                Histogram[channel[x+i][j]] = Histogram[channel[x+i][j]] +1
                if channel[x+i][j] <= med :
                    n = n + 1
        for y in range(1,len(channel[0])-1):
            if y == 1:
                pass
            else :
                # 更新直方图,并更新 n 的值
                for i in range(-1,2) :
                    # 在直方图中删除下方不用的像素
                    Histogram[channel[x+i][y-2]] = Histogram[channel[x+i][y-2]] - 1
                    if channel[x+i][y-2] <= med :
                        n = n - 1
                    # 在直方图中添加上方要用的像素
                    Histogram[channel[x+i][y+1]] = Histogram[channel[x+i][y+1]] + 1
                    if channel[x+i][y+1] <= med :
                        n = n + 1
                # 更新 med 的值
                if n > 5 :
                    while n > 5 :
                        if med == 0 :
                            break
                        n = n - int(Histogram[med])
                        med = med - 1
                elif n < 5 :
                    while n < 5 :
                        med = med + 1
                        n = n + int(Histogram[med])
            # 存入结果
            channel2[x][y] = med
    return channel2    
    
# 对图像进行中值滤波
def medianBlur(image):
    shape = list(image.shape)
    if len(shape) != 3:
        print(" 请输入三通道的图像!")
        return
    else:
        b,g,r=cv2.split(image)
        # b = medianBlurChannel(b)
        # g = medianBlurChannel(g)
        # r = medianBlurChannel(r)
        b = fastMedianBlurChannel(b)
        g = fastMedianBlurChannel(g)
        r = fastMedianBlurChannel(r)
        return cv2.merge([b,g,r])

if __name__ == "__main__":
    start = datetime.datetime.now()
    img = cv2.imread('45-salt.jpg')
    img2 = medianBlur(img)
    img3 = cv2.medianBlur(img,3)
    cv2.imshow('src',img)
    cv2.imshow('result',img2)
    cv2.imshow('standard',img3)
    end = datetime.datetime.now()
    print(end-start)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

参考

快速中值滤波——Python 实现