Accepts a function to be approximated and a list of x coordinates that are the endpoints of the interpolation intervals. Generates one cubic polynomial per interval, matching the function's value and slope at both ends of that interval. Can generate fairly fast C code, or can be used directly in Python.
class Interpolator:
    """Piecewise-cubic approximation of a scalar function.

    One cubic ``((A*x + B)*x + C)*x + D`` is fitted to every pair of
    neighbouring points so that it reproduces the function's value and
    slope at both ends of that interval (cubic Hermite interpolation).
    An instance can be called like the function it approximates, or
    turned into C source with ``c_code()``.

    Attributes:
        name: Suffix used in the name of the generated C function.
        intervals: List of ``(u, A, B, C, D)`` tuples, one per interval
            and in the order of ``points``; ``u`` is the left endpoint
            of the interval the coefficients belong to.
    """

    def __init__(self, name, func, points, deriv=None, h=0.01):
        """Fit one cubic to each interval between consecutive points.

        Args:
            name: Used for naming the C function
                (``interpolator_<name>``).
            func: The function to approximate; called with one x value.
            points: Sequence of x coordinates, the endpoints of the
                interpolation intervals.  They must be strictly
                increasing for the interval lookup to be meaningful.
            deriv: Optional derivative of ``func``.  When omitted the
                slope at each point ``p`` is estimated with the forward
                difference ``(func(p + h) - func(p)) / h``.
            h: Step of that forward difference; ignored when ``deriv``
                is given.  Note that ``func`` is then also evaluated at
                ``p + h``, which for the last point lies to the right
                of the interpolation range.

        Raises:
            ValueError: If fewer than two points are supplied, because
                no interval can be formed.
            ZeroDivisionError: If two consecutive points are equal
                (and the coordinates are ints or floats).
        """
        points = list(points)
        if len(points) < 2:
            raise ValueError(
                "at least two points are needed to form an interval")
        self.name = name

        # Evaluate the function (and slope) once per point; interior
        # points are shared by the two intervals that meet there.
        values = [func(p) for p in points]
        if deriv is None:
            slopes = [(func(p + h) - fp) / h
                      for p, fp in zip(points, values)]
        else:
            slopes = [deriv(p) for p in points]

        # Generate a cubic spline for each interpolation interval.
        self.intervals = intervals = []
        for i in range(len(points) - 1):
            coeffs = self._fit_cubic(points[i], points[i + 1],
                                     values[i], values[i + 1],
                                     slopes[i], slopes[i + 1])
            intervals.append((points[i],) + coeffs)

    @staticmethod
    def _fit_cubic(u, v, FU, FV, DU, DV):
        """Return ``(A, B, C, D)`` of the cubic on the interval [u, v].

        The cubic takes the values ``FU``/``FV`` and the slopes
        ``DU``/``DV`` at ``u``/``v`` respectively.
        """
        denom = (u - v)**3
        A = ((-DV - DU) * v + (DV + DU) * u +
             2 * FV - 2 * FU) / denom
        B = -((-DV - 2 * DU) * v**2 +
              u * ((DU - DV) * v + 3 * FV - 3 * FU) +
              3 * FV * v - 3 * FU * v +
              (2 * DV + DU) * u**2) / denom
        C = (- DU * v**3 +
             u * ((- 2 * DV - DU) * v**2 + 6 * FV * v
                  - 6 * FU * v) +
             (DV + 2 * DU) * u**2 * v + DV * u**3) / denom
        D = -(u * (-DU * v**3 - 3 * FU * v**2) +
              FU * v**3 + u**2 * ((DU - DV) * v**2 + 3 * FV * v) +
              u**3 * (DV * v - FV)) / denom
        return A, B, C, D

    def __call__(self, x):
        """Evaluate the approximation at ``x``.

        The interval is the last one whose left endpoint is <= ``x``.
        Values of ``x`` left of the first point use the first cubic and
        values right of the last point use the last cubic, i.e. the
        end cubics are extrapolated.

        Raises:
            IndexError: If ``self.intervals`` is empty.
        """
        intervals = self.intervals
        # Binary search on the left endpoints.  Working with indices
        # instead of list slices keeps the run-time proportional to
        # the log of the number of intervals.
        lo, hi = 0, len(intervals)
        while hi - lo >= 2:
            mid = lo + (hi - lo) // 2
            if x < intervals[mid][0]:
                hi = mid
            else:
                lo = mid
        u, A, B, C, D = intervals[lo]
        # Plug coefficients into polynomial (Horner form).
        return ((A * x + B) * x + C) * x + D

    def c_code(self):
        """Generate C code to efficiently implement this interpolator.

        The result is a single-line definition of
        ``double interpolator_<name>(double x)`` that selects the
        coefficients with nested ``if``/``else`` statements mirroring
        the binary search of ``__call__``.  Run the C code through
        'indent' if you need it to be legible.

        Returns:
            The C source as a string.

        Raises:
            IndexError: If ``self.intervals`` is empty.
        """
        intervals = self.intervals

        def code_choice(lo, hi):
            # Emit the selection code for intervals[lo:hi].
            if hi - lo < 2:
                return ("A=%.10e;B=%.10e;C=%.10e;D=%.10e;"
                        % intervals[lo][1:])
            mid = lo + (hi - lo) // 2
            return ("if (x < %.10e) {%s} else {%s}"
                    % (intervals[mid][0],
                       code_choice(lo, mid),
                       code_choice(mid, hi)))

        return ("double interpolator_%s(double x) {" % self.name +
                "double A,B,C,D;%s" % code_choice(0, len(intervals)) +
                "return ((A * x + B) * x + C) * x + D;}")
|
I was hoping this would beat the library sqrt() function on my Linux box, but the generated C code is a little slower. Modern compilers use loop-unrolling and modern CPUs use branch prediction to minimize pipeline disruption, and nested if-else statements mess up one or both of those. If there were a way to turn a floating-point comparison into a 0 or 1 without a branch instruction, then I could use it in a multiply, but I don't see a way to do that.