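# -*- coding: UTF-8 -*-
# Python CGDAL类——支持栅格数据的栅格计算/线性增强/滤波增强
# (Python CGDAL class: raster calculation, linear enhancement and filtering enhancement for raster data)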

'''
python version: 2.7.11
numpy  ver=1.11.1
gdal   ver=2.0.3
Author: Liuph
Date: 2016/9/9
Description: This is a GDAL class adapted from the Python GDAL/OGR Cookbook documentation
(http://pcjericks.github.io/py-gdalogr-cookbook/raster_layers.html). The CGDAL class can be used to
load an image and obtain descriptive information about the image file. Basic image processing
is included in the class: a linear enhancement or a spatial filtering operation can be performed
via the CGDAL class. Moreover, this module offers some functions to generate a raster image from a numpy array.
However, exception handling is not complete. For instance, an image whose nodata value is "None"
may or may not cause problems; this still needs to be confirmed.
In addition, as is well known, Python code is very concise and clear. However, Python code runs
roughly thirty times slower than equivalent C++ code, which is a serious and longstanding shortcoming.
May Python get better!
'''



from osgeo import gdal, gdalnumeric, ogr, osr
from PIL import Image, ImageDraw
import os, sys
from gdalconst import *
import struct
import numpy as np
import re
gdal.UseExceptions()


class CGDAL:
    """Wrapper around a GDAL raster dataset opened read-only.

    After a successful ``loadFrom(filename)`` the whole raster is cached in
    ``mpArray`` with shape (bands, rows, cols). The other methods provide band
    statistics, pixel/ground/geographic coordinate conversion, linear
    stretching, spatial filtering and band math on that data. Output rasters
    are written through the module-level ``array2MultiBandsrasterfn``.
    """
    # ---- data members (class-level defaults, overwritten per instance) ----
    mpoDataset = None            # opened gdal.Dataset
    __mpData = None              # cached pixel data; only tested by isValid()
    mpArray = np.array([])       # pixel data, shape (bands, rows, cols) after loadFrom
    mgDataType = GDT_Byte        # GDAL data type of band 1
    msDataType = ""              # display name of mgDataType
    mnRows = mnCols = mnBands = -1
    mnDatalength = -1
    mpGeoTransform = []          # 6-element GDAL geotransform
    mpGeoTransfor = []           # misspelled legacy name, kept only for backward compatibility
    msProjectionRef = ""         # projection WKT
    msFilename = ""
    mdInvalidValue = 0.0         # nodata value of band 1 (None when the band has none)
    mnPerPixSize = 1

    srcSR = None                 # dataset spatial reference ("ground" coordinates)
    latLongSR = None             # geographic coordinate system cloned from srcSR
    poTransform = None           # srcSR -> latLongSR
    poTransformT = None          # latLongSR -> srcSR

    # Display names printed by printRasterAttr(); other types fall back to
    # gdal.GetDataTypeName().
    _DATA_TYPE_NAMES = {
        GDT_Byte: "Byte",
        GDT_UInt16: "Unsigned Int 16",
        GDT_Int16: "Signed Int 16",
        GDT_UInt32: "Unsigned Int 32",
        GDT_Int32: "Signed Int 32",
        GDT_Float32: "Float 32",
        GDT_Float64: "Float 64",
    }

    # ---- methods ----
    def __init__(self):
        pass

    def __del__(self):
        """Release the dataset handle and the cached pixel data.

        Only plain None assignments are made here. Module globals (np,
        GDT_Byte) may already be torn down when this runs at interpreter exit.
        """
        self.mpoDataset = None
        self.__mpData = None
        self.mpArray = None
        self.srcSR = self.latLongSR = None
        self.poTransform = self.poTransformT = None

    def read(self, band, row, col):
        """Return the cached value at 0-based (band, row, col)."""
        return self.mpArray[band, row, col]

    def printimg(self):
        """Print the cached pixel array."""
        print(self.mpArray)

    def isValid(self):
        """Return True when a dataset is open and its pixel data is cached."""
        return self.__mpData is not None and self.mpoDataset is not None

    def _inverseGeoTransform(self):
        """Return the inverse (ground -> pixel) of mpGeoTransform.

        Raises ValueError if the geotransform cannot be inverted.
        """
        result = gdal.InvGeoTransform(self.mpGeoTransform)
        # GDAL >= 2.0 returns the inverted geotransform (None on failure);
        # GDAL 1.x returned a (success, geotransform) pair.
        if result is not None and len(result) == 2:
            success, result = result
            if not success:
                result = None
        if result is None:
            raise ValueError("Geotransform of '%s' is not invertible" % self.msFilename)
        return result

    def world2Pixel(self, lat, lon):
        """Convert a geographic (lat, lon) point to fractional pixel coordinates.

        The point is projected from latLongSR to the dataset SRS with
        poTransformT. When the dataset has no SRS (poTransformT is None), lon
        and lat are used directly as ground X and Y.

        Returns {'x': column, 'y': row} as floats.
        """
        gX, gY = lon, lat
        if self.poTransformT is not None:
            # NOTE(review): assumes the traditional GIS (lon, lat) axis order
            # used by GDAL 2.x; GDAL >= 3 may expect (lat, lon) - confirm.
            gX, gY = self.poTransformT.TransformPoint(lon, lat)[:2]
        x, y = gdal.ApplyGeoTransform(self._inverseGeoTransform(), gX, gY)
        return {'x': x, 'y': y}

    def pixel2World(self, x, y):
        """Convert fractional pixel coordinates to geographic coordinates.

        Pixel -> ground uses mpGeoTransform; ground -> lon/lat uses poTransform
        (srcSR -> latLongSR). Without an SRS the ground coordinates are
        returned unchanged.

        Returns {'lon': ..., 'lat': ...}.
        """
        gX, gY = gdal.ApplyGeoTransform(self.mpGeoTransform, x, y)
        lon, lat = gX, gY
        if self.poTransform is not None:
            # NOTE(review): (lon, lat) output order assumes GDAL 2.x axis order.
            lon, lat = self.poTransform.TransformPoint(gX, gY)[:2]
        return {'lon': lon, 'lat': lat}

    def pixel2Ground(self, x, y):
        """Convert pixel coordinates to ground coordinates of the dataset SRS."""
        pX, pY = gdal.ApplyGeoTransform(self.mpGeoTransform, x, y)
        return {'pX': pX, 'pY': pY}

    def ground2Pixel(self, pX, pY):
        """Convert ground coordinates of the dataset SRS to fractional pixel coordinates."""
        x, y = gdal.ApplyGeoTransform(self._inverseGeoTransform(), pX, pY)
        return {'x': x, 'y': y}

    def loadFrom(self, filename):
        """Open `filename` read-only and cache its metadata and pixel data.

        Returns True on success. Returns False if the file cannot be opened or
        has no raster band.
        """
        # close the previously opened image
        self.mpoDataset = None

        try:
            self.mpoDataset = gdal.Open(filename, GA_ReadOnly)
        except RuntimeError as e:
            print('Unable to open %s' % filename)
            print(e)
            return False
        if self.mpoDataset is None:
            # gdal.Open signals failure with None if GDAL exceptions are turned off
            print('Unable to open %s' % filename)
            return False

        self.msFilename = filename

        # basic attributes
        self.mnRows = self.mpoDataset.RasterYSize
        self.mnCols = self.mpoDataset.RasterXSize
        self.mnBands = self.mpoDataset.RasterCount
        if self.mnBands < 1:
            print('%s has no raster band' % filename)
            return False
        band1 = self.mpoDataset.GetRasterBand(1)
        self.mgDataType = band1.DataType
        self.mdInvalidValue = band1.GetNoDataValue()

        # GeoTransform: [0] top-left x, [1] w-e pixel resolution, [2] 0,
        # [3] top-left y, [4] 0, [5] n-s pixel resolution (negative value)
        self.mpGeoTransform = self.mpoDataset.GetGeoTransform()
        self.msProjectionRef = self.mpoDataset.GetProjection()

        # Coordinate transformations are only meaningful when an SRS exists.
        self.srcSR = self.latLongSR = None
        self.poTransform = self.poTransformT = None
        if self.msProjectionRef:
            self.srcSR = osr.SpatialReference(self.msProjectionRef)  # ground
            self.latLongSR = self.srcSR.CloneGeogCS()                 # geographic
            self.poTransform = osr.CoordinateTransformation(self.srcSR, self.latLongSR)
            self.poTransformT = osr.CoordinateTransformation(self.latLongSR, self.srcSR)

        self.msDataType = self._DATA_TYPE_NAMES.get(
            self.mgDataType, gdal.GetDataTypeName(self.mgDataType))

        # Read all bands at once. Integers are widened to int64 and reals to
        # float64, which matches the Python-number based array the class used
        # to build.
        data = self.mpoDataset.ReadAsArray()
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.int64)
        elif np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float64)
        self.mpArray = data.reshape((self.mnBands, self.mnRows, self.mnCols))
        self.__mpData = self.mpArray
        return True

    def getRasterBand(self, band_num):
        """Return the gdal.Band `band_num` (1-based); exit with status 1 if it does not exist."""
        try:
            srcband = self.mpoDataset.GetRasterBand(band_num)
        except RuntimeError as e:
            print('Band ( %i ) not found' % band_num)
            print(e)
            sys.exit(1)
        return srcband

    def getRasterBand2Array(self, band_num):
        """Read band `band_num` (1-based) from the dataset as a 2-D numpy array."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        return srcband.ReadAsArray()

    def getRasterBandStas(self, band_num):
        """Print min, max, mean and standard deviation of band `band_num` as computed by GDAL."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)

        stats = srcband.GetStatistics(True, True)
        if stats is None:
            print("Statistics of Band %i is NULL" % band_num)
            sys.exit(1)

        print("[ STATS ] =  Minimum=%.3f, Maximum=%.3f, Mean=%.3f, StdDev=%.3f" % (
            stats[0], stats[1], stats[2], stats[3]))

    def getRasterBandInfo(self, band_num):
        """Print descriptive metadata (nodata, min, max, scale, unit, color table) of a band.

        A missing color table is reported and the method returns. It no longer
        terminates the program.
        """
        srcband = self.mpoDataset.GetRasterBand(band_num)
        if srcband is None:
            print("Band %i is NULL" % band_num)
            sys.exit(1)
        # The double space after '=' reproduces the original print-statement output.
        print("[ NO DATA VALUE ] =  %s" % (srcband.GetNoDataValue(),))
        print("[ MIN ] =  %s" % (srcband.GetMinimum(),))
        print("[ MAX ] =  %s" % (srcband.GetMaximum(),))
        print("[ SCALE ] =  %s" % (srcband.GetScale(),))
        print("[ UNIT TYPE ] =  %s" % (srcband.GetUnitType(),))

        ctable = srcband.GetColorTable()
        if ctable is None:
            print('No ColorTable found')
            return

        print("[ COLOR TABLE COUNT ] =  %s" % (ctable.GetCount(),))
        for i in range(0, ctable.GetCount()):
            entry = ctable.GetColorEntry(i)
            if not entry:
                continue
            print("[ COLOR ENTRY RGB ] =  %s" % (ctable.GetColorEntryAsRGB(i, entry),))

    def _getValidBandArray(self, band_num):
        """Return a float64 copy of cached band `band_num` (1-based) with nodata set to NaN.

        A copy is used so that mpArray is never modified. Float is used so
        that NaN can be stored even for integer rasters.
        """
        arr = self.mpArray[band_num - 1, :, :].astype(np.float64)
        if self.mdInvalidValue is not None:
            arr[arr == self.mdInvalidValue] = np.nan
        return arr

    def getRasterBandMinVal(self, band_num):
        """Minimum of a band, ignoring nodata pixels."""
        return np.nanmin(self._getValidBandArray(band_num))

    def getRasterBandMaxVal(self, band_num):
        """Maximum of a band, ignoring nodata pixels."""
        return np.nanmax(self._getValidBandArray(band_num))

    def getRasterBandMeanVal(self, band_num):
        """Mean of a band, ignoring nodata pixels."""
        return np.nanmean(self._getValidBandArray(band_num))

    def getRasterBandStdVal(self, band_num):
        """Standard deviation of a band, ignoring nodata pixels."""
        return np.nanstd(self._getValidBandArray(band_num))

    def getRasterBandVarVal(self, band_num):
        """Variance of a band, ignoring nodata pixels."""
        return np.nanvar(self._getValidBandArray(band_num))

    def raster2shp(self, band_num, dst_layername):
        """Polygonize band `band_num` into the shapefile `<dst_layername>.shp` (use with caution)."""
        srcband = self.mpoDataset.GetRasterBand(band_num)
        drv = ogr.GetDriverByName("ESRI Shapefile")
        dst_ds = drv.CreateDataSource(dst_layername + ".shp")
        dst_layer = dst_ds.CreateLayer(dst_layername, srs=None)
        gdal.Polygonize(srcband, None, dst_layer, -1, [], callback=None)
        dst_ds = None  # close the data source so the shapefile is written out

    def replaceNoData2New(self, ds_fn, new_NoData):
        """Write a copy of the raster to `ds_fn` with each band's nodata pixels set to `new_NoData`.

        Each band's own nodata value is used. A band without one falls back
        to -9999, the value this method always assumed before. The output
        declares `new_NoData` as its nodata value.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            org_Nodata = self.mpoDataset.GetRasterBand(band_num).GetNoDataValue()
            if org_Nodata is None:
                org_Nodata = -9999
            # float64 so that new_NoData fits regardless of the band's data type
            rasterArray = self.getRasterBand2Array(band_num).astype(np.float64)
            if np.isnan(org_Nodata):
                mask = np.isnan(rasterArray)  # NaN never compares equal
            else:
                mask = rasterArray == org_Nodata
            rasterArray[mask] = new_NoData
            outArr[band_num - 1, :, :] = rasterArray
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, new_NoData)

    def linearEnhance(self, ds_fn, _MinValue, _MaxValue):
        """Linearly stretch every band to [_MinValue, _MaxValue] and write the result to `ds_fn`.

        The input range [min, max] of each band excludes pixels equal to
        mdInvalidValue. Pixels outside that range, NaN pixels and nodata
        pixels get the band's own nodata value (NaN if the band has none). A
        constant band (max == min) maps to _MinValue instead of dividing by
        zero. The output is written through array2MultiBandsrasterfn.
        """
        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            print("Linear Cal %i/%i" % (band_num, self.mnBands))
            _nodata = self.mpoDataset.GetRasterBand(band_num).GetNoDataValue()
            _array = self.getRasterBand2Array(band_num).astype(np.float64)
            _min = self.getRasterBandMinVal(band_num)
            _max = self.getRasterBandMaxVal(band_num)

            valid = (_array >= _min) & (_array <= _max)
            if self.mdInvalidValue is not None:
                valid &= _array != self.mdInvalidValue

            _newarray = np.empty(_array.shape, dtype=np.float32)
            _newarray.fill(np.nan if _nodata is None else _nodata)
            span = float(_max - _min)
            if span > 0:
                _newarray[valid] = (_array[valid] - _min) / span * (_MaxValue - _MinValue) + _MinValue
            else:
                _newarray[valid] = _MinValue
            outArr[band_num - 1, :, :] = _newarray
        print("Writing output data...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    def spatialFiltering(self, ds_fn, sAlgorithm="MeanFiltering"):
        """Apply a convolution kernel selected by `sAlgorithm` to every band and write to `ds_fn`.

        The kernel is applied without flipping. Border pixels that the kernel
        window does not fully cover keep their original values. An unknown
        algorithm name prints a message and exits with status 1.
        "HorizontalMaskFiltering" and "VerticalMaskFiltering" are accepted as
        aliases of the legacy misspelled names.
        """
        kernels = {
            "MeanFiltering": np.ones((3, 3)) / 9.0,
            "LaplaceFiltering": np.array([[-1.0, -1.0, -1.0], [-1.0, 9.0, -1.0], [-1.0, -1.0, -1.0]]),
            "WallisFiltering": np.array([[0.0, -0.25, 0.0], [-0.25, 1.0, -0.25], [0.0, -0.25, 0.0]]),
            "SobelXFiltering": np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]),
            "SobelYFiltering": np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]),
            "LogFiltering": np.array([[-2., -4., -4., -4., -2.],
                                      [-4., 0., 8., 0., -4.],
                                      [-4., 8., 24., 8., -4.],
                                      [-4., 0., 8., 0., -4.],
                                      [-2., -4., -4., -4., -2.]]),
            "RelievoFiltering": np.array([[-3.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]]),
            "HorizonalMaskFiltering": np.array([[3.0, 3.0, 3.0], [-6.0, -6.0, -6.0], [3.0, 3.0, 3.0]]),
            "VerticalnMaskFiltering": np.array([[3.0, -6.0, 3.0], [3.0, -6.0, 3.0], [3.0, -6.0, 3.0]]),
            "DiagonalMaskFiltering": np.array([[3.0, 3.0, -6.0], [3.0, -6.0, 3.0], [-6.0, 3.0, 3.0]]),
            # NOTE(review): the final +1 breaks this kernel's symmetry and zero
            # sum; -1 may have been intended - confirm before changing.
            "QualcommEdgeDec": np.array([[-1.0, 0.0, -1.0], [0.0, 4.0, 0.0], [-1.0, 0.0, 1.0]]),
        }
        kernels["HorizontalMaskFiltering"] = kernels["HorizonalMaskFiltering"]
        kernels["VerticalMaskFiltering"] = kernels["VerticalnMaskFiltering"]

        if sAlgorithm not in kernels:
            print("There is no such filtering algorithm called %s" % sAlgorithm)
            sys.exit(1)
        algori = kernels[sAlgorithm]
        print("Filtering Algorithm: \n%s" % (algori,))

        window_size = algori.shape[0]   # all kernels are square with odd size
        subsize = window_size // 2      # half window; border width left untouched
        inner_rows = self.mnRows - 2 * subsize
        inner_cols = self.mnCols - 2 * subsize

        outArr = np.zeros((self.mnBands, self.mnRows, self.mnCols))
        for band_num in range(1, self.mnBands + 1):
            src = self.mpArray[band_num - 1].astype(np.float64)
            filtered = src.copy()  # border pixels keep their original values
            if inner_rows > 0 and inner_cols > 0:
                # Accumulate each kernel tap as a shifted slice of the band,
                # in the same (x, y) order as a per-pixel loop would.
                acc = np.zeros((inner_rows, inner_cols))
                for x in range(window_size):
                    for y in range(window_size):
                        acc += src[x:x + inner_rows, y:y + inner_cols] * algori[x, y]
                filtered[subsize:subsize + inner_rows, subsize:subsize + inner_cols] = acc
            outArr[band_num - 1, :, :] = filtered
            print("Filtered %i/%i" % (band_num, self.mnBands))
        print("Writing output file...")
        array2MultiBandsrasterfn(self.msFilename, ds_fn, outArr, self.mnBands, self.mdInvalidValue)

    def rasterCalculation(self, ds_fn, expr="(band3-band2)/(band3+band2)"):
        """Evaluate a band-math expression on the cached bands and write the single-band result.

        Bands are referenced as ``band<N>`` (1-based), e.g. ``band1``,
        ``band12``. Numeric constants are left untouched. A reference to a
        band that does not exist prints a message and exits with status 1.
        """
        def _bandRef(match):
            index = int(match.group(1)) - 1
            if not 0 <= index < self.mnBands:
                print("Band %s does not exist (image has %i bands)" % (match.group(1), self.mnBands))
                sys.exit(1)
            return '1.0*self.mpArray[%i,:,:]' % index

        expr = re.sub(r'band(\d+)', _bandRef, expr)
        print(expr)

        # NOTE(security): eval() executes arbitrary code; only pass trusted expressions.
        resultArr = eval(expr)
        array2MultiBandsrasterfn(self.msFilename, ds_fn, resultArr, 1, self.mdInvalidValue)

    def printRasterAttr(self):
        """Print size, pixel size, data type, nodata value, geotransform and projection."""
        print("File Name: %s" % self.msFilename)
        print("Rows: %i   Cols: %i   Bands: %i   Pixel Size: %.2f*%.2f" % (
            self.mnRows, self.mnCols, self.mnBands,
            self.mpGeoTransform[1], -self.mpGeoTransform[5]))
        print("Data Type: %s   No-Data Value:  %s" % (self.msDataType, self.mdInvalidValue))
        # This field is the geotransform; it was previously mislabeled "SpatialRef".
        print("GeoTransform: %s    \nProjection: %s" % (self.mpGeoTransform, self.msProjectionRef))


def array2MultiBandsrasterfn(rasterfn, newRasterfn, array, bandCount, nodata = None):
    """Write `array` to a new multi-band Float32 GeoTIFF, using `rasterfn` as a template.

    The size, geotransform and projection are copied from the existing raster
    `rasterfn`. `array` must hold bandCount * rows * cols values. It is
    reshaped to (bandCount, rows, cols) without modifying the caller's array.
    `nodata` is set on every output band unless it is None.
    """
    raster = gdal.Open(rasterfn)
    geotransform = raster.GetGeoTransform()
    cols = raster.RasterXSize
    rows = raster.RasterYSize
    projection = raster.GetProjectionRef()
    raster = None  # the template is no longer needed

    # reshape() instead of assigning .shape so the caller's array is left intact
    data = np.asarray(array).reshape((bandCount, rows, cols))

    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bandCount, gdal.GDT_Float32)
    # Copy the full geotransform so rotation terms [2] and [4] are preserved.
    outRaster.SetGeoTransform(geotransform)
    outRasterSRS = osr.SpatialReference()
    outRasterSRS.ImportFromWkt(projection)
    outRaster.SetProjection(outRasterSRS.ExportToWkt())
    for band_num in range(1, bandCount + 1):
        outband = outRaster.GetRasterBand(band_num)
        # SetNoDataValue takes a number; skip it when no nodata value is wanted.
        if nodata is not None:
            outband.SetNoDataValue(nodata)
        outband.WriteArray(data[band_num - 1, :, :])
        outband.FlushCache()
    outRaster = None  # close the dataset so everything is written to disk
    print("write output file -- %s success!" % newRasterfn)

def creatraster(newRasterfn, GeoTransform, projection, datatype, imgdata, cols, rows, bands):
    """Create a GeoTIFF `newRasterfn` from numpy data.

    `imgdata` must be indexable as (bands, rows, cols). When bands == 1, a
    (rows, cols) array is also accepted; it is reshaped without modifying the
    caller's array. `datatype` is a GDAL data type constant. `GeoTransform`
    and `projection` (WKT) are written unchanged.
    """
    if bands == 1:
        imgdata = np.reshape(imgdata, (bands, rows, cols))
    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, bands, datatype)
    outRaster.SetGeoTransform(GeoTransform)
    outRaster.SetProjection(projection)
    for i in range(bands):
        outband = outRaster.GetRasterBand(i + 1)
        outband.WriteArray(imgdata[i, :, :])
        outband.FlushCache()
    outRaster = None  # close the dataset so the file is complete on disk
    print("write data succeed!")


def rasterCalculations(ds_fn, expr):
    """Evaluate a band-math expression whose operands are single-band .tif files.

    Every ``name.tif`` token in `expr` is loaded with CGDAL. It becomes a
    float64 array in which that raster's own nodata pixels are NaN, and it is
    substituted into `expr`, which is then evaluated with eval(). The result is
    written to `ds_fn` using the alphabetically first input file as the
    georeferencing template. NaN results are stored as that file's nodata
    value when it has one.

    File names must start with a letter or underscore, contain only letters,
    digits and underscores, and include the ``.tif`` suffix (no directory part).
    The process exits with status 1 if no file is referenced, a file cannot be
    loaded, or a file has more than one band.
    """
    # NOTE(security): `expr` goes through eval(); never pass untrusted input.
    pattern = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]*\.tif\b')
    unim = np.unique(np.array(pattern.findall(expr)))
    print(unim)
    if unim.size == 0:
        print("No .tif file found in expression: %s" % expr)
        sys.exit(1)

    mArrs = []
    no_data = None
    for i, item in enumerate(unim):
        Cgdal = CGDAL()
        if not Cgdal.loadFrom(item):
            sys.exit(1)
        if Cgdal.mnBands != 1:
            print("The input raster is not useful. Only 1 band is required, %i is given." % Cgdal.mnBands)
            sys.exit(1)
        if i == 0:
            no_data = Cgdal.mdInvalidValue
        # float copy so NaN can mark nodata even for integer rasters
        band = Cgdal.mpArray.astype(np.float64)
        if Cgdal.mdInvalidValue is not None:
            band[band == Cgdal.mdInvalidValue] = np.nan
        mArrs.append(band)

    # Substitute all file names in one pass, so a name that is a suffix of
    # another (e.g. a.tif / ba.tif) cannot corrupt the expression.
    index_of = dict((str(item), i) for i, item in enumerate(unim))
    expr = pattern.sub(lambda m: '1.0*mArrs[%i]' % index_of[m.group(0)], expr)
    print(expr)
    resultArr = eval(expr)
    if no_data is not None:
        resultArr = np.where(np.isnan(resultArr), no_data, resultArr)
    array2MultiBandsrasterfn(unim[0], ds_fn, resultArr, 1, nodata=no_data)

# 转载自 (Reposted from): https://blog.csdn.net/liuph_/article/details/52491123