def tavolsag(xe,ye,x,y):
    """Return the straight-line (Euclidean) distance between (xe, ye) and (x, y)."""
    d_first = xe - x
    d_second = ye - y
    return sqrt(d_first ** 2 + d_second ** 2)
def pontbeszur(xe,ye,x,y,resz):
    """Return the raster cells of the points that evenly divide a segment.

    The segment from (xe, ye) to (x, y) is split into ``resz`` equal parts and
    the ``resz - 1`` interior division points are converted to cell indices
    using the module-level grid parameters ``cx``, ``cy`` and ``cs``.

    Coordinates follow this script's convention: ``xe``/``x`` are measured
    against ``cx`` to give the row index, ``ye``/``y`` against ``cy`` to give
    the column index.

    Returns a list of ``[row, col]`` pairs ordered from (xe, ye) towards
    (x, y). ``resz`` is expected to be a positive whole number (it is
    truncated with ``int``); values below 2 yield an empty list.
    """
    steps = int(resz)
    cells = []
    for i in range(1, steps):
        # Compute each fraction directly instead of accumulating 1/resz, so
        # rounding error does not build up along long segments.
        frac = float(i) / steps
        xk = (1 - frac) * xe + frac * x
        yk = (1 - frac) * ye + frac * y
        row = floor((cx - xk) / cs)
        col = floor((yk - cy) / cs)
        cells.append([row, col])
    return cells
    
fname="sometiff.tif" #INSERT HERE THE FILE NAME OF YOUR GEOTIFF (32-bit floating point)
from math import *
try:
    from osgeo import gdal
    from osgeo.gdalconst import *
except ImportError:  # fallback for GDAL installs that only ship the legacy top-level modules
    import gdal
    from gdalconst import *
dataset=gdal.Open(fname,GA_ReadOnly)
if dataset is None:
    raise IOError('Could not open raster: ' + fname)
# Read band 1 only, so t is always a 2-D (rows x columns) array even when the
# file holds several bands (Dataset.ReadAsArray would then return a 3-D array).
t=dataset.GetRasterBand(1).ReadAsArray()
cm=dataset.RasterXSize  # number of columns
cn=dataset.RasterYSize  # number of rows
gt=dataset.GetGeoTransform()
# Cell size is taken from the pixel width; square pixels are assumed.
cs=abs(gt[1])
# NOTE: the names are swapped w.r.t. the usual x/y convention:
# cx is geotransform[3] (Y of the top-left corner) and
# cy is geotransform[0] (X of the top-left corner).
cx=gt[3]
cy=gt[0]
# ujt is an alias of t, not a copy: in-place changes to ujt also change t.
ujt=t

import os, sys
try:
    from osgeo import ogr
except ImportError:  # fallback for GDAL installs that only ship the legacy top-level module
    import ogr
fpath='D:/yourdirectory' #INSERT HERE THE PATH OF YOUR SHAPEFILE
os.chdir(fpath)

# River mask, indexed c[row][col]: cn rows (RasterYSize) by cm columns
# (RasterXSize), i.e. the same orientation as the elevation array t.
c = [[0] * cm for _ in range(cn)]

fname='yourShapefileNameWITHOUTextension' #INSERT HERE THE FILENAME OF YOUR SHAPEFILE (CONTAINING RIVER DATA) without extension
fajlnev=fname+'.shp'


def _cell_index(map_x, map_y):
    """Return (row, col) of the raster cell containing a map point, or None if outside the raster."""
    row = floor((cx - map_y) / cs)
    col = floor((map_x - cy) / cs)
    if 0 <= row < cn and 0 <= col < cm:
        return row, col
    return None


def _trace_line(points, mask):
    """Mark one polyline's cells in ``mask`` and return its elevation profile.

    ``points`` is a sequence of vertex tuples whose first two items are the
    map X and map Y coordinates (as returned by OGR ``Geometry.GetPoints``).
    Segments longer than one cell are densified with ``pontbeszur`` at a
    spacing of at most one cell size. Returns ``[map_x, map_y, elevation]``
    entries in vertex order. Densified points are stored at the centre of the
    cell they fall in, so converting their coordinates back to a cell index
    yields that same cell. Points outside the raster are dropped.
    """
    profile = []
    prev = None
    for pt in points:
        map_x, map_y = pt[0], pt[1]
        if prev is not None:
            prev_x, prev_y = prev
            if map_x == prev_x and map_y == prev_y:
                continue  # duplicate vertex: zero-length segment
            # tavolsag/pontbeszur take (row-axis, column-axis) pairs,
            # i.e. map Y before map X.
            dist = tavolsag(prev_y, prev_x, map_y, map_x)
            if dist > cs:
                for row, col in pontbeszur(prev_y, prev_x, map_y, map_x, ceil(dist / cs)):
                    if 0 <= row < cn and 0 <= col < cm:
                        mask[row][col] = 1
                        profile.append([cy + (col + 0.5) * cs,
                                        cx - (row + 0.5) * cs,
                                        t[row][col]])
        prev = (map_x, map_y)
        cell = _cell_index(map_x, map_y)
        if cell is not None:
            row, col = cell
            mask[row][col] = 1
            profile.append([map_x, map_y, t[row][col]])
    return profile


driver = ogr.GetDriverByName('ESRI Shapefile')
# Open read-only (0): the shapefile is only read here.
v = driver.Open(fajlnev, 0)
if v is None:
    raise IOError('Could not open shapefile: ' + fajlnev)
layer = v.GetLayer()
vt = []
numFeatures = layer.GetFeatureCount()
for i in range(0, numFeatures):
    feat = layer.GetNextFeature()
    if feat is None:
        break
    geom = feat.GetGeometryRef()
    if geom is None:
        continue  # feature without geometry
    gname = geom.GetGeometryName()
    if gname == "LINESTRING":
        vt.append(_trace_line(geom.GetPoints() or [], c))
    elif gname == "MULTILINESTRING":
        # Print the 'nev' attribute only when the layer actually has that field.
        if feat.GetFieldIndex('nev') >= 0:
            print(feat.GetField('nev'))
        for j in range(0, geom.GetGeometryCount()):
            part = geom.GetGeometryRef(j)
            vt.append(_trace_line(part.GetPoints() or [], c))
v.Destroy()

# ASCII grid header; the dimensions are taken from the mask itself so the
# header always matches the rows/columns written below.
nrows = len(c)
ncols = len(c[0]) if nrows else 0
with open(fname + ".asc", 'w') as fn:
    fn.write('ncols ' + str(ncols) + '\n' + 'nrows ' + str(nrows) + '\n'
             + 'xllcorner ' + str(cy) + '\n'
             + 'yllcorner ' + str(cx - nrows * cs) + '\n'
             + 'cellsize ' + str(cs) + '\n')
    for row in c:
        fn.write(''.join(str(val) + ' ' for val in row) + '\n')
#----------------------------------------------------------Processing
# Burn the traced river profiles into the elevation grid: walk each profile
# from its higher end towards its lower end and never let the elevation rise
# again. A vertex higher than the lowest elevation seen so far (zmin) is
# lowered to zmin, and the cells in a window of +/-2 rows and +/-1 column
# around it are set to the mean of their value in t and zmin.
# NOTE: ujt may be the very same array object as t (it is assigned with
# ``ujt=t`` when the raster is loaded); in that case the averaging reads
# values already lowered by earlier steps.
n_rows = len(ujt)
n_cols = len(ujt[0]) if n_rows else 0
for profile in vt:
    if len(profile) < 2:
        continue  # empty or single-point profile: nothing to enforce
    elso = profile[0][2]
    utolso = profile[-1][2]
    if elso > utolso:  # The line starts on the hill
        order = range(1, len(profile))
        zmin = elso
    else:  # The line starts at the outfall
        order = range(len(profile) - 2, -1, -1)
        zmin = utolso
    for j in order:
        y, x, z = profile[j]
        s = floor((cx - x) / cs)
        o = floor((y - cy) / cs)
        if z > zmin:
            profile[j][2] = zmin
            for dr in range(-2, 3):
                for dc in range(-1, 2):
                    rr, cc = s + dr, o + dc
                    # Skip cells outside the grid instead of wrapping around
                    # (negative index) or raising IndexError at the border.
                    if 0 <= rr < n_rows and 0 <= cc < n_cols:
                        ujt[rr][cc] = (t[rr][cc] + zmin) / 2
            # Set the centre last so it ends at zmin whether or not ujt aliases t.
            if 0 <= s < n_rows and 0 <= o < n_cols:
                ujt[s][o] = zmin
        else:
            zmin = z

                    
# Write the burned elevation grid (ujt) as an ASCII grid to <fname>new.asc:
# a five-line header followed by one line per raster row, each value
# followed by a single space.
header = ''.join([
    'ncols ', str(cm), '\n',
    'nrows ', str(cn), '\n',
    'xllcorner ', str(cy), '\n',
    'yllcorner ', str(cx - cn * cs), '\n',
    'cellsize ', str(cs), '\n',
])
with open(fname + "new.asc", 'w') as fn:
    fn.write(header)
    for row in ujt:
        fn.write(''.join(str(value) + ' ' for value in row) + '\n')

