Overload the Drift¶

In [1]:
import gstlearn as gl
import numpy as np

Creation of the class¶

In [2]:
class LocalKriging:
    """Kriging at a single, movable target point with user-defined external drifts.

    Each callable in ``driftFunctions`` becomes an external drift:

    - It is evaluated on the data coordinates. The results are stored as
      ``addedDrift<i>`` columns carrying the ``F`` locator in a copy of the input Db.
    - In :meth:`eval`, it is evaluated on the target coordinates.

    The model copy is given a drift of order 0 plus one external drift per
    function. The kriging system, with a unique neighborhood, is built once.
    :meth:`eval` only moves the single target point and re-estimates.

    Parameters
    ----------
    driftFunctions : iterable of callables
        Each maps a 2-D array of coordinates (one row per point, one column per
        space dimension) to one drift value per row.
    db : gl.Db
        Input data base. It is copied, not modified. Its ``X`` locators give the
        coordinates and its ``Z`` locator the estimated variable.
    model : gl.Model
        Covariance model. It is copied before its drift part is set.

    Raises
    ------
    RuntimeError
        If the kriging system cannot be configured or prepared.
    """

    # Prefix of the external-drift column names, used in both data bases.
    _DRIFT_PREFIX = "addedDrift"

    def __init__(self, driftFunctions, db, model):
        # Materialize once: the functions are iterated several times below and
        # their count must stay in sync with the model's external drifts even
        # if the caller later mutates its own list.
        self.drifts = list(driftFunctions)
        self.ndim = db.getNDim()

        # Input Data Base: a copy of `db` augmented with one drift column per function.
        self.dbc = gl.Db(db)
        namesCoords = db.getNamesByLocator(gl.ELoc.X)
        coords = db[namesCoords]
        for i, f in enumerate(self.drifts):
            self.dbc[self._DRIFT_PREFIX + str(i)] = f(coords)
        self.dbc.setLocators([self._DRIFT_PREFIX + "*"], gl.ELoc.F)

        # Output Data Base: a single target point created at the origin (eval
        # moves it). Its drift columns are therefore initialized at the origin,
        # without reading the coordinates back from the Db.
        self.dbout = gl.Db.createFromOnePoint([0.0] * self.ndim)
        iptrEst = self.dbout.addColumnsByConstant(
            1, radix="z_estim", locatorType=gl.ELoc.Z
        )
        origin = np.zeros((1, self.ndim))
        for i, f in enumerate(self.drifts):
            self.dbout.addColumns(f(origin), self._DRIFT_PREFIX + str(i))
        self.dbout.setLocators([self._DRIFT_PREFIX + "*"], gl.ELoc.F)

        # Model: drift of order 0 plus one external drift per drift function.
        self.modelc = gl.Model(model)
        self.modelc.setDriftIRF(0, len(self.drifts))

        # Neighborhood (Unique)
        self.neigh = gl.NeighUnique()

        # Build the kriging system once. The Dbs, model and neighborhood are kept
        # as attributes, presumably because the system references them — confirm.
        self.ks = gl.KrigingSystem(self.dbc, self.dbout, self.modelc, self.neigh)
        if self.ks.updKrigOptEstim(iptrEst, -1, -1):
            raise RuntimeError("LocalKriging: could not set the estimation option")
        if not self.ks.isReady():
            raise RuntimeError("LocalKriging: the kriging system is not ready")

    def eval(self, coordsTarget):
        """Krige the data at one target location and return the estimate.

        Parameters
        ----------
        coordsTarget : sequence of float
            Target coordinates. The first ``self.ndim`` values position the
            target point. The whole sequence is passed to the drift functions
            as a single-row array.

        Returns
        -------
        The estimate written in the ``z_estim`` column of the output Db.

        Raises
        ------
        RuntimeError
            If the kriging system reports an error during estimation.
        """
        # Drift functions expect one row per point: build the 1-row view once.
        target2d = np.atleast_2d(coordsTarget)

        for idim in range(self.ndim):
            self.dbout.setCoordinate(0, idim, coordsTarget[idim])
        for i, f in enumerate(self.drifts):
            self.dbout.setLocVariable(gl.ELoc.F, 0, i, f(target2d)[0])

        if self.ks.estimate(0):
            raise RuntimeError("LocalKriging.eval: kriging estimation failed")
        return self.dbout["z_estim"][0]

Db creation¶

In [3]:
N_SAMPLES = 100  # number of random data points

# Fixed seed (legacy global API kept on purpose): the recorded outputs below
# depend on these exact draws, taken in the order x, y, z.
np.random.seed(123)
db = gl.Db.create()
db["x"] = np.random.uniform(size=N_SAMPLES)  # uniform on [0, 1)
db["y"] = np.random.uniform(size=N_SAMPLES)  # uniform on [0, 1)
db["z"] = np.random.normal(size=N_SAMPLES)  # standard normal variable to krige
db.setLocators(["x", "y"], gl.ELoc.X)
db.setLocators(["z"], gl.ELoc.Z)

Drift functions¶

In [4]:
def driftFunction1(coords):
    """Return the squared first coordinate, one value per row of ``coords``."""
    first_axis = coords[:, 0]
    return np.square(first_axis)


def driftFunction2(coords):
    """Return the squared second coordinate, one value per row of ``coords``."""
    second_axis = coords[:, 1]
    return np.square(second_axis)


# The user-defined drifts, in the order they are indexed by LocalKriging.
driftFunctions = [driftFunction1, driftFunction2]

Model¶

In [5]:
# Covariance model with a single Matern structure (no nugget term is added).
# Only `param` is set explicitly; presumably it is the Matern smoothness
# parameter — confirm against the gstlearn documentation.
model = gl.Model.createFromParam(
    gl.ECov.MATERN,
    param=1,
)

Class instantiation¶

In [6]:
A = LocalKriging(driftFunctions, db, model)
estimate = A.eval(np.array([0.3, 0.4]))  # kriged value at (0.3, 0.4)
print(round(estimate, 4))
0.4335

Test: kriging at a data point reproduces the data value¶

In [7]:
# With this nugget-free model and a unique neighborhood, kriging is an exact
# interpolator: estimating at a data location must give back the datum itself.
namesCoords = db.getNamesByLocator(gl.ELoc.X)
testPoint = db[namesCoords][0]  # coordinates of the first sample
estimateAtData = A.eval(np.array(testPoint))
np.testing.assert_allclose(estimateAtData, db["z"][0], rtol=0, atol=1e-4)
print(round(db["z"][0], 4))
-0.7408