Overload the Drift¶
In [1]:
import gstlearn as gl
import numpy as np
Creation of the class¶
In [2]:
class LocalKriging :
    """Point-wise kriging whose drift terms are user-supplied Python functions.

    At construction every drift function is evaluated on the data coordinates
    and stored as a column with locator ``gl.ELoc.F`` in a copy of the input
    data base (built with ``gl.Db(db)``). A one-sample output data base is
    created once and re-used for every target: `eval` moves that sample,
    refreshes its drift values and runs the kriging system on it.

    Parameters
    ----------
    driftFunctions : iterable of callables
        Each callable is given a 2-D array (one row per point, one column per
        coordinate) and must return one drift value per row.
    db : gl.Db
        Input data base holding the coordinates (locator X) and the variable.
    model : gl.Model
        Covariance model; a copy is made before its drift is redefined.

    Raises
    ------
    RuntimeError
        If the kriging system reports that it is not ready.
    """

    def __init__(self,driftFunctions,db,model):
        # Snapshot the drift functions: the model below is told how many
        # there are, so the list must not change (or be exhausted) afterwards
        self.drifts = list(driftFunctions)
        # Input Data Base
        self.ndim = db.getNDim()
        self.dbc = gl.Db(db)
        namesCoords = db.getNamesByLocator(gl.ELoc.X)
        coords = db[namesCoords]
        for i,f in enumerate(self.drifts):
            self.dbc["addedDrift" + str(i)] = f(coords)
        self.dbc.setLocators(["addedDrift*"],gl.ELoc.F)
        # Output Data Base: a single sample, initially located at the origin
        self.dbout = gl.Db.createFromOnePoint([0] * self.ndim)
        iptrEst = self.dbout.addColumnsByConstant(1,radix="z_estim",locatorType = gl.ELoc.Z)
        # Create the drift columns of the output sample (values are
        # overwritten by each call to eval)
        origin = self.dbout.getAllCoordinates().transpose()[0]
        for i,f in enumerate(self.drifts):
            self.dbout.addColumns(f(np.atleast_2d(np.array(origin))), "addedDrift" + str(i))
        self.dbout.setLocators(["addedDrift*"],gl.ELoc.F)
        # Model
        self.modelc = gl.Model(model)
        self.modelc.setDriftIRF(0,len(self.drifts))
        # Neighborhood (Unique)
        self.neigh = gl.NeighUnique()
        # Instantiate the KrigingSystem
        self.ks = gl.KrigingSystem(self.dbc, self.dbout, self.modelc, self.neigh)
        self.ks.updKrigOptEstim(iptrEst, -1, -1)
        # Do not ignore the status: estimating with a system that is not
        # ready would silently return a meaningless value
        if not self.ks.isReady():
            raise RuntimeError("LocalKriging: the kriging system is not ready")

    def eval(self,coordsTarget):
        """Return the kriging estimate at one target location.

        Parameters
        ----------
        coordsTarget : array-like
            Coordinates of the target. Only its first ``ndim`` entries are
            written as coordinates of the output sample; the drift functions
            receive ``np.atleast_2d(coordsTarget)``.

        Raises
        ------
        RuntimeError
            If the kriging system returns a non-zero error code.
        """
        for i in range(self.ndim):
            self.dbout.setCoordinate(0, i, coordsTarget[i])
        # The target is presented to the drift functions as a one-row array,
        # i.e. with the same layout as the data coordinates
        point = np.atleast_2d(coordsTarget)
        for i,f in enumerate(self.drifts):
            self.dbout.setLocVariable(gl.ELoc.F,0, i, f(point)[0])
        if self.ks.estimate(0) != 0:
            raise RuntimeError("LocalKriging: estimation failed at the target")
        return self.dbout["z_estim"][0]
Db creation¶
In [3]:
# Synthetic data set: 100 samples with uniform coordinates and a
# standard normal variable
db = gl.Db.create()
np.random.seed(123)
# The draws are made in the order x, y, z so that the fixed seed always
# yields the same data set
db["x"] = np.random.uniform(low=0.0, high=1.0, size=100)
db["y"] = np.random.uniform(low=0.0, high=1.0, size=100)
db["z"] = np.random.normal(loc=0.0, scale=1.0, size=100)
# x and y are the coordinates, z is the variable of interest
db.setLocators(["x", "y"], gl.ELoc.X)
db.setLocators(["z"], gl.ELoc.Z)
Drift functions¶
In [4]:
def driftFunction1(coords):
    """Drift term equal to the square of the first coordinate.

    `coords` is indexed as a 2-D array (one row per point, one column per
    coordinate); one value is returned per row.
    """
    firstCoord = coords[:, 0]
    return np.square(firstCoord)
def driftFunction2(coords):
    """Drift term equal to the square of the second coordinate.

    `coords` is indexed as a 2-D array (one row per point, one column per
    coordinate); one value is returned per row.
    """
    secondCoord = coords[:, 1]
    return np.square(secondCoord)
# The position in this list is the rank of the corresponding drift term
driftFunctions = [
    driftFunction1,
    driftFunction2,
]
Model¶
In [5]:
# Covariance model with a single Matern structure; param=1 is presumably the
# Matern smoothness parameter, every other setting keeps its default value
model = gl.Model.createFromParam(
    gl.ECov.MATERN,
    param=1,
)
Class instantiation¶
In [6]:
# Build the kriging object once; it can then be evaluated at any target
A = LocalKriging(driftFunctions, db, model)
# Estimate at an arbitrary location
target = np.array([0.3, 0.4])
estimate = A.eval(target)
print(round(estimate, 4))
0.4335
Test¶
In [7]:
# Take the first data sample from its named coordinate columns, so that the
# point does not depend on how getAllCoordinates() orders samples and
# dimensions
testPoint = np.array([db["x"][0], db["y"][0]])
# Keep the kriged value instead of discarding it: the model has no nugget
# effect, so the estimate at a data point should match the observed value
estimAtData = A.eval(testPoint)
print(round(estimAtData, 4))
print(round(db["z"][0], 4))
-0.7408