Skip to content

Commit 87a3379

Browse files
committed
Adding more robust linear interpolation method where spline does not work.
1 parent 877dd24 commit 87a3379

4 files changed

Lines changed: 77 additions & 20 deletions

File tree

Project.toml

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
11
name = "OptimalTransportNetworks"
22
uuid = "e2b46e68-897f-4e4e-ba36-a93c9789fd96"
33
authors = ["Sebastian Krantz <sebastian.krantz@graduateinstitute.ch>"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
88
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
99
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11-
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
11+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
1212
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617

1718
[compat]
18-
julia = "1.8.5"
19-
Dierckx = "0.5.0"
20-
Ipopt = "1.4.0"
21-
JuMP = "1.20.0"
22-
MathOptInterface = "1.0.0"
23-
Plots = "1.19.0"
24-
Random = "1.0.0"
25-
SparseArrays = "1.0.0"
26-
Statistics = "1.0.0"
27-
LinearAlgebra = "1.0.0"
19+
julia = "1.8.5, 2"
20+
Ipopt = "1.4.0, 2"
21+
JuMP = "1.20.0, 2"
22+
LinearAlgebra = "1.0.0, 2"
23+
NearestNeighbors = "0.4.10, 1"
24+
Plots = "1.19.0, 2"
25+
Dierckx = "0.5.0, 1"
26+
Random = "1.0.0, 2"
27+
SparseArrays = "1.0.0, 2"
28+
StaticArrays = "1.1.0, 2"
29+
Statistics = "1.0.0, 2"
30+
31+
32+
2833

2934

src/OptimalTransportNetworks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,11 @@ plot_graph(graph, results_annealing[:Ijk])
7373
module OptimalTransportNetworks
7474

7575
using LinearAlgebra, JuMP, Plots
76-
using SparseArrays: sparse
7776
using Statistics: mean
7877
# using MathOptInterface: Parameter
7978
using Dierckx: Spline2D, evaluate
79+
using NearestNeighbors: KDTree, knn
80+
# using Interpolations: cubic_spline_interpolation
8081
import Ipopt, Plots, Random #, MathOptSymbolicAD
8182
# import MathOptInterface as MOI
8283

src/main/helper.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,48 @@ function rescale_network!(param, graph, I1, Il, Iu; max_iter = 100)
285285

286286
return I1
287287
end
288+
289+
# KNN Version: More robust
290+
function linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap)
291+
292+
# Ensure input vectors are of the same length
293+
@assert length(vec_x) == length(vec_y) == length(vec_map) "Input vectors must have the same length"
294+
295+
# Ensure input vectors are Float64
296+
vec_x = convert(Vector{Float64}, vec_x)
297+
vec_y = convert(Vector{Float64}, vec_y)
298+
vec_map = convert(Vector{Float64}, vec_map)
299+
xmap = convert(Vector{Float64}, xmap)
300+
ymap = convert(Vector{Float64}, ymap)
301+
302+
# Initialize the output array
303+
fmap = zeros(length(xmap), length(ymap))
304+
305+
# Create a KDTree for efficient nearest neighbor search
306+
points = hcat(vec_x, vec_y)
307+
tree = KDTree(points'; leafsize = 5)
308+
309+
# Determine the number of neighbors to use (k)
310+
k = min(15, size(points, 1)) # Use 15 or the total number of points, whichever is smaller
311+
312+
for (ix, x) in enumerate(xmap), (iy, y) in enumerate(ymap)
313+
314+
# Find the 15 nearest neighbors
315+
idxs, dists = knn(tree, [x, y], k, true)
316+
317+
# If the point is exactly on a known point, use that value
318+
if dists[1] 0
319+
fmap[ix, iy] = vec_map[idxs[1]]
320+
continue
321+
end
322+
323+
# Weights
324+
weights = 1 ./ dists.^2
325+
weights ./= sum(weights)
326+
327+
# Interpolate
328+
fmap[ix, iy] = sum(weights .* vec_map[idxs])
329+
end
330+
331+
return fmap
332+
end

src/main/plot_graph.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,21 @@ function plot_graph(graph, edges = nothing; kwargs...)
9494
vec_map = vec(op.map)
9595
end
9696
# Interpolate map onto grid
97-
# itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear()))
98-
spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1)
9997
xmap = range(minimum(vec_x), stop=maximum(vec_x), length=2*length(vec_x))
10098
ymap = range(minimum(vec_y), stop=maximum(vec_y), length=2*length(vec_y))
101-
Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))'
102-
Xmap, Ymap = Xmap[:], Ymap[:]
103-
fmap = evaluate(spl, Xmap, Ymap)
104-
# make fmap a matrix with same size as xmap and ymap
105-
fmap = reshape(fmap, length(xmap), length(ymap))
99+
# itp = interpolate((vec_x, vec_y), vec_map, Gridded(Linear()))
100+
fmap = zeros(length(xmap), length(ymap))
101+
try
102+
spl = Spline2D(vec_x, vec_y, vec_map, s = 0.1)
103+
Xmap, Ymap = xmap' .* ones(length(ymap)), ymap .* ones(length(xmap))'
104+
Xmap, Ymap = Xmap[:], Ymap[:]
105+
fmap_values = evaluate(spl, Xmap, Ymap)
106+
fmap = reshape(fmap_values, length(xmap), length(ymap))
107+
catch
108+
# println("Spline2D interpolation failed, falling back to linear interpolation")
109+
# If Spline2D interpolation fails, fall back to linear interpolation
110+
fmap = linear_interpolation_2d(vec_x, vec_y, vec_map, xmap, ymap)
111+
end
106112

107113
# Plot heatmap
108114
heatmap!(pl, xmap, ymap, fmap,

0 commit comments

Comments
 (0)