import functools
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import xarray as xr
import salvus.namespace as sn
from salvus import fem
from salvus.mesh.algorithms.unstructured_mesh import metrics
from salvus.mesh.layered_meshing import simple_post_refinement

mesh = sn.layered_meshing.mesh_from_domain(
domain=sn.domain.dim2.BoxDomain.from_bounds((-1, +1), (-1, +1)),
model=sn.material.from_params(vp=2.0, rho=1.0),
mesh_resolution=sn.MeshResolution(
reference_frequency=1.0, elements_per_wavelength=2.0
),
)
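
Before looking at the mesh in detail, a quick back-of-the-envelope check of the resolution settings is useful: with vp = 2 m/s and a reference frequency of 1 Hz the minimum wavelength is 2 m, and at 2 elements per wavelength the target edge length is roughly 1 m, so the 2 m x 2 m box should end up as roughly a 2 x 2 grid of elements. The lines below are just that arithmetic, not a Salvus call:

# Rough expectation for the element count (plain arithmetic).
min_wavelength = 2.0 / 1.0  # vp / reference_frequency
target_edge = min_wavelength / 2.0  # elements_per_wavelength = 2.0
print((2.0 / target_edge) ** 2)  # ~4 elements for the 2 m x 2 m box
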
# Plot the mesh
fig = plt.figure(figsize=(5, 5))
ax = fig.subplots(1, 1)
ax.scatter(*mesh.points.T, c="b")
# Label each point with its id.
for i, (x, y) in enumerate(mesh.points):
ax.annotate(
str(i),
(x, y),
xytext=(5, 5),
textcoords="offset points",
fontsize=8,
ha="left",
)
ax.set_aspect("equal")
ax.set_xlabel("x [m]")
ax.set_ylabel("x [m]")
plt.show()

mesh.get_element_nodes()
array([[[-1., -1.],
        [ 0., -1.],
        [-1.,  0.],
        [ 0.,  0.]],

       [[ 0., -1.],
        [ 1., -1.],
        [ 0.,  0.],
        [ 1.,  0.]],

       [[-1.,  0.],
        [ 0.,  0.],
        [-1.,  1.],
        [ 0.,  1.]],

       [[ 0.,  0.],
        [ 1.,  0.],
        [ 0.,  1.],
        [ 1.,  1.]]])

mesh.get_element_nodes() is essentially a convenience function; under the hood, it just returns mesh.points[mesh.connectivity]. This means the connectivity array is the real source of new information here, defining how points are grouped into elements.
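
As a quick check of that claim (a minimal sketch, reusing the mesh built above), we can compare the two arrays directly:

# get_element_nodes() should be nothing more than a coordinate lookup
# through the connectivity array.
print(np.array_equal(mesh.get_element_nodes(), mesh.points[mesh.connectivity]))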
axs[0].set_title("Tensor Product Ordering")
axs[1].set_title("Counterclockwise Ordering")
axs[0].set_ylabel("$s$")
corner_nodes = np.array(
[
[-1.0, -1.0],
[1.0, -1.0],
[-1.0, 1.0],
[1.0, 1.0],
]
)
# Plot the edges with the order being counterclockwise (CCW)
ccw_ordering = np.array([0, 1, 3, 2])
for ax in axs:
ax.set_aspect("equal")
ax.fill(
corner_nodes[ccw_ordering, 0],
corner_nodes[ccw_ordering, 1],
facecolor="none",
edgecolor="k",
linestyle="--",
)
ax.plot(corner_nodes[:, 0], corner_nodes[:, 1], "bo")
ax.set_xlabel("$r$")
# Place the annotation slightly inside of the element
for i, corner_node in enumerate(corner_nodes):
axs[0].annotate(
i, xy=corner_node * 0.8, fontsize=20, ha="center", va="center"
)
axs[1].annotate(
ccw_ordering[i],
xy=corner_node * 0.8,
fontsize=20,
ha="center",
va="center",
)
plt.show()

# Plot the points of the first element.
plt.plot(*mesh.get_element_nodes()[0].T, "--o", color="k", markerfacecolor="b")
plt.axis("equal")
plt.show()

# Permute the connectivity according to the right-hand rule.
plt.plot(
*mesh.get_element_nodes()[0, [0, 1, 3, 2, 0]].T,
"--o",
color="k",
markerfacecolor="b",
)
plt.axis("equal")
plt.show()

def plot_mesh(
mesh: sn.UnstructuredMesh,
values: npt.NDArray | None = None,
annotate: bool = False,
title: str | None = None,
ax: plt.Axes | None = None,
):
"""
Plot a 2D mesh with optional color mapping and value annotations.
Creates a visualization of an unstructured mesh where each element can be
colored according to a scalar field. Elements are filled with colors from a
colormap, edges are outlined in black, and vertices are marked with dots.
Args:
mesh: The mesh to visualize. Must be a 2D quadrilateral mesh.
values: Scalar values to map to colors for each element. If None,
elements are colored gray. Array length must match mesh.nelem.
annotate: If True, display the numeric value at the center of each
element.
title: Title of the plot. If None, no title is added.
        ax: Matplotlib axes to plot on. If None, a new figure and axes are
            created.
"""
# Get element nodes and reorder for visualization (right-hand rule)
nodes = mesh.get_element_nodes()[:, [0, 1, 3, 2]]
# Close each element by appending the first node to create filled polygons
nodes_closed = np.concatenate((nodes, nodes[:, [0]]), axis=1)
# Create figure and axes
if ax is None:
f, ax = plt.subplots()
else:
f = plt.gcf()
ax.set_aspect("equal")
if values is not None:
# Normalize values to [0,1] range for colormap mapping
norm_values = np.divide(
(values - values.min()).astype(float),
(values.max() - values.min()).astype(float),
where=~np.isclose(values.max(), values.min()),
out=np.full_like(values, 0.5, dtype=float),
)
# Apply colormap to normalized values
colors = matplotlib.colormaps["plasma_r"](norm_values)
# Plot each element with its corresponding color
for i, elm in enumerate(nodes_closed):
ax.fill(*elm.T, facecolor=colors[i], edgecolor="k", alpha=0.8)
ax.plot(*elm.T, "ko", markersize=1)
# Optionally add text annotation at element center
if annotate:
# Calculate element centroid (excluding duplicate closing node)
cx, cy = elm[:-1].mean(axis=0)
                # Display the value, centered in the middle of the element.
ax.text(cx, cy, f"{values[i]}", ha="center", va="center")
# Create and add colorbar for value mapping
sm = plt.cm.ScalarMappable(
cmap="plasma_r",
norm=plt.Normalize(vmin=values.min(), vmax=values.max()),
)
plt.colorbar(sm, ax=ax)
else:
# Default visualization: gray elements with black edges
for elm in nodes_closed:
ax.fill(*elm.T, facecolor="lightgray", edgecolor="k", alpha=0.5)
ax.plot(*elm.T, "ko", markersize=6, zorder=1)
if title is not None:
ax.set_title(title)
    return f, ax

plot_mesh(mesh, np.arange(mesh.nelem), annotate=True, title="Element Indices")

def rotate_mesh(
mesh: sn.UnstructuredMesh, angle: float
) -> sn.UnstructuredMesh:
"""
Rotate a mesh by `angle` degrees.
Args:
mesh: The mesh. Will be copied.
angle: The angle in degrees. Positive angles rotate clockwise.
Returns:
A rotated mesh.
"""
theta = np.deg2rad(angle)
rotation_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
_mesh = mesh.copy()
_mesh.points[:] = _mesh.points @ rotation_matrix
return _mesh
mesh_r = rotate_mesh(mesh, -45.0)
plot_mesh(
mesh_r,
values=np.arange(mesh_r.nelem),
annotate=True,
title="Element Indices",
)

We can compute the Jacobian of each element using the salvus.fem module below. Since the transformation from reference to physical coordinates is affine here, the Jacobian will be the same no matter where in the element we compute it.
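
As a quick check of the affine claim (a minimal sketch, assuming fem.jacobian.j accepts any reference coordinate in [-1, 1]^2), we can evaluate the Jacobian at two different reference points and compare:

# For the purely rotated (affine) mesh, the Jacobian at the element center
# and at a reference corner should agree to machine precision.
j_center = fem.jacobian.j(np.array([0.0, 0.0]), mesh_r.get_element_nodes())
j_corner = fem.jacobian.j(np.array([-1.0, -1.0]), mesh_r.get_element_nodes())
print(np.allclose(j_center, j_corner))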

# Compute the Jacobian at a reference coordinate for each element.
j = fem.jacobian.j(np.array([0.0, 0.0]), mesh_r.get_element_nodes())
j
array([[[ 0.35355339,  0.35355339],
        [-0.35355339,  0.35355339]],

       [[ 0.35355339,  0.35355339],
        [-0.35355339,  0.35355339]],

       [[ 0.35355339,  0.35355339],
        [-0.35355339,  0.35355339]],

       [[ 0.35355339,  0.35355339],
        [-0.35355339,  0.35355339]]])

def compute_gll_points(mesh: sn.UnstructuredMesh, n: int):
"""
Compute the physical coordinates of GLL points for all elements in a mesh.
Maps GLL points from the reference element to physical space using
interpolation. The GLL points define where Lagrange basis functions are
centered in spectral element methods.
Args:
mesh: The mesh containing element connectivity and vertex coordinates.
n: The polynomial order of the GLL points. Number of points per
element dimension is (n+1).
"""
# Get coefficients to map from mesh geometry order to GLL order.
ic = fem.jacobian.get_interpolation_coefficients_from_order(
n_dim=mesh.ndim, from_order=mesh.shape_order, to_order=n
)
# Interpolate mesh vertex x-coordinates to GLL point locations.
gll_x = ic @ mesh.get_element_nodes()[..., None, 0]
# Interpolate mesh vertex y-coordinates to GLL point locations.
gll_y = ic @ mesh.get_element_nodes()[..., None, 1]
# Stack coordinates and flatten to get all GLL points as (x,y) pairs.
return np.stack([gll_x.ravel(), gll_y.ravel()])
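
For reference, the 1-D GLL points of order n are the interval endpoints -1 and +1 together with the roots of the derivative of the Legendre polynomial of degree n. The sketch below is NumPy-only and independent of the Salvus helpers used above:

def gll_points_1d(n: int) -> npt.NDArray:
    """Return the n + 1 GLL points on [-1, 1] for polynomial order n."""
    legendre = np.polynomial.legendre.Legendre.basis(n)
    interior = legendre.deriv().roots()
    return np.concatenate(([-1.0], np.sort(interior), [1.0]))

# Order 2 gives the familiar [-1, 0, +1]; higher orders cluster points near
# the element edges, which is what drives the time-step discussion below.
print(gll_points_1d(2))
print(gll_points_1d(4))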
# Get the GLL points for a few different orders
orders = [1, 2, 4, 7, 15, 21]
_, axs = plt.subplots(2, 3, figsize=(10, 7), sharey=True, sharex=True)
for row in range(axs.shape[0]):
for col in range(axs.shape[1]):
# Get the i^th index for fetching the order
i = row * axs.shape[1] + col
# Plot the corner nodes of the elements and the GLL points
plot_mesh(mesh, title=f"Order: {orders[i]}", ax=axs[row, col])
axs[row, col].plot(
*compute_gll_points(mesh, orders[i]), "ro", markersize=2, zorder=0
)
plt.show()

In the following cells we will work our way up to the compute_time_step(...) method within Salvus' mesh.algorithms.unstructured_mesh.metrics module.

def plot_reciprocal_lattice(mesh: sn.UnstructuredMesh) -> None:
"""
Plot reciprocal lattice vectors on top of a mesh visualization.
Visualizes the reciprocal lattice basis vectors used in time-step
estimation. As we only compute the Jacobians at the element centers here,
the linearization error will be apparent for non-affine elements (those
which can not be deformed by a simple rotation and scaling).
Args:
mesh: The mesh to analyze. Jacobians are computed at element centers.
"""
j = fem.jacobian.j(np.array([0.0, 0.0]), mesh.get_element_nodes())
# Plot vectors at the center of each element, oriented along the columns of
# j
for (x, y), j_elm in zip(mesh.get_element_nodes().mean(axis=1), j):
# The columns of inverse Jacobian can be interpreted as vectors normal
# to constant [r, s] isolines in physical space. This acts as a basis
# for the reciprocal lattice.
a0, a1 = np.linalg.inv(j_elm).T
# The Miller indices to construct a diagonal lattice in a 2-D crystal
# are [+1, +1] and [+1, -1]. Construct the two lattice vectors here.
r0 = 1 * a0 + 1 * a1
r1 = 1 * a0 - 1 * a1
# Get the direction of the normal planes to r0 and r1 for plotting.
l0 = np.array([-r0[1], r0[0]]) / np.linalg.det(np.linalg.inv(j_elm.T))
l1 = np.array([-r1[1], r1[0]]) / np.linalg.det(np.linalg.inv(j_elm.T))
# Add quiver plot for l0 and l1 vectors at element centers. These
# should point along the lattice plane directions.
plt.quiver(x, y, *l0)
plt.quiver(x, y, *l1)
# Plot an estimate of the lattice itself. Since we're only computing
# the Jacobian at a single point (the element center), this is only
# accurate for affine element mappings (i.e. constant Jacobians). It
# can nevertheless be an interesting exercise to see how the lattice is
# distorted in non-affine elements.
plt.plot([x - l0[0], x, x + l0[0]], [y - l0[1], y, y + l0[1]], "r-.")
plt.plot([x - l1[0], x, x + l1[0]], [y - l1[1], y, y + l1[1]], "m-.")
plot_mesh(mesh)
plot_reciprocal_lattice(mesh)
plt.gca().set_aspect("equal")compute_minimum_distance function below, which each line annotated in order
to give additional context. This is the exact same algorithm that executes in
the solver upon simulation startup, and which is a key factor in determining
the ever-so-important time step of a simulation.def compute_minimum_distance(mesh: sn.UnstructuredMesh, n: int):
"""
Compute minimum distance between GLL points using a reciprocal lattice.
Estimates the smallest physical spacing between neighboring GLL points in a
deformed mesh using crystallographic techniques.
Args:
mesh: The mesh containing element connectivity and geometry.
n: The polynomial order.
"""
# Get the positions of the GLL points in reference coordinates.
gll_1d = fem.utils.gll_points_from_order(n)
gll_nd = fem.utils.tensorized_quadrature_points(mesh.ndim, n)
# Compute the Jacobian at each GLL point.
j = fem.jacobian.j(gll_nd, mesh.get_element_nodes())
# The columns of the transposed inverse jacobian are the reciprocal lattice
# vectors.
r_lat = np.swapaxes(np.linalg.inv(j), -1, -2)
# Multiply the reciprocal lattice vectors by the diagonal Miller indices on
# a quad to get the reciprocal wavevectors that encode the spacing between
# adjacent lattice planes.
r_wv = r_lat @ np.stack(([+1, +1], [+1, -1])).T
# Compute the norm of each lattice vector. Get the maximum value for all
# GLL points and lattice directions per element. This represents the
# largest reciprocal wavevector, which in turn represents the smallest
# physical distance between two lattice planes.
r_wv_max_norm = np.max(np.linalg.norm(r_wv, axis=-2), axis=(1, 2))
# Scale this wavevector by the minimum GLL spacing, and multiply by sqrt(2)
# to adjust the scaling factor to account for edge-length.
    return np.min(np.diff(gll_1d)) * np.sqrt(2) / r_wv_max_norm
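
A quick sanity check (a small sketch, assuming the axis-aligned 2 x 2 example mesh from the start of this notebook is still in scope): every element there is an undeformed unit square, so at order 1 the closest GLL points are neighbouring corners and the minimum distance should simply be the 1 m edge length.

# Expect a value of roughly 1.0 for each of the four unit-square elements.
print(compute_minimum_distance(mesh, 1))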

def compute_time_step(
    n: int, mesh: sn.UnstructuredMesh, max_v: float | npt.NDArray = 1.0
):
"""
Compute stable time step for explicit time integration using CFL condition.
Args:
n: The polynomial order.
mesh: The mesh.
max_v: Maximum velocity per element. Can be scalar (constant) or array
with length matching mesh.nelem for spatially varying velocities.
"""
cfl = {2: 0.6, 3: 0.47}
min_dist = compute_minimum_distance(mesh, n)
return np.asarray(cfl[mesh.ndim] * min_dist / max_v)
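
Because the minimum 1-D GLL spacing shrinks roughly like 1 / n**2 near the element edges, the stable time step drops quickly with polynomial order even on a fixed mesh. A small sweep below (a sketch reusing the helpers defined above and the example mesh with vp = 2.0, assuming the fem.utils helpers accept these orders):

# The stable time step shrinks with increasing polynomial order.
for order in [1, 2, 4, 7]:
    dt = np.min(compute_time_step(order, mesh, max_v=2.0))
    print(f"order {order}: dt ~ {dt:.4f} s")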
n = 1
plot_mesh(
mesh,
compute_time_step(n, mesh),
annotate=True,
title="Time step [s]",
)
plt.plot(*compute_gll_points(mesh, n), "ro", markersize=2)
smallest_time_step, time_step_per_elm = metrics.compute_time_step(
mesh,
1.0,
simulation_order=n,
)
plot_mesh(
mesh,
time_step_per_elm,
annotate=True,
title="Time step [s] (from salvus)",
)
plt.plot(*compute_gll_points(mesh, n), "ro", markersize=2)
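
The two figures should tell the same story; as a direct numerical cross-check (a small sketch comparing the per-element values produced by the two routines above):

# Our hand-rolled estimate and Salvus' metrics.compute_time_step() should
# agree up to floating-point noise if they implement the same algorithm.
print(np.allclose(compute_time_step(n, mesh), time_step_per_elm))
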
def generate_nontrivial_mesh(
interlayer_coarsening_policy: (
sn.layered_meshing.meshing_protocol.coarsening_policy.InterlayerCoarseningPolicy
| None
) = None,
intralayer_coarsening_policy: (
sn.layered_meshing.meshing_protocol.coarsening_policy.IntralayerCoarseningPolicy
| None
) = None,
local_refinement_policy: (
sn.layered_meshing.LocalRefinementPolicy | None
) = None,
add_low_vel: bool = False,
):
"""
Generate a complex two-layer ocean-bottom mesh with optional velocity
heterogeneity.
Creates a simple 2D mesh representing an ocean-sediment system with a
water layer above and a sediment layer below. Includes configurable mesh
refinement policies and an optional Gaussian velocity anomaly to
demonstrate adaptive meshing strategies and their impact on computational
cost.
Args:
interlayer_coarsening_policy: Policy for mesh transitions between
material layers. Controls element size changes at material
boundaries.
intralayer_coarsening_policy: Policy for mesh refinement within
individual layers. Can be single policy or list for layer-specific
control.
local_refinement_policy: Policy for targeted mesh refinement in
specific regions. Typically used around velocity anomalies or
interfaces.
add_low_vel: If True, adds a Gaussian low-velocity anomaly in the
sediment layer to create challenging time-step conditions.
"""
    # Define the spatial grid for velocity field interpolation: the x-axis
    # spans 5 km horizontally and the y-axis spans 2.5 km vertically.
x, y = np.linspace(0, 5_000, 101), np.linspace(-2500.0, 0.0, 101)
    # Create 2D coordinate meshes for field evaluation. Use 'ij' indexing so
    # that xx[i, j] corresponds to x[i] and yy[i, j] corresponds to y[j].
xx, yy = np.meshgrid(x, y, indexing="ij")
# Define the center of our anomaly.
cx, cy = 2500.0, -1750.0
# Create velocity perturbation that results in a challenging low-velocity
# zone that will reduce time steps.
anom = np.where(
(np.abs(np.hypot((xx - cx), yy - cy) - 500.0) <= np.max(np.diff(x))),
-2000.0,
0.0,
)
# Define computational domain.
domain = sn.domain.dim2.BoxDomain.from_bounds((0, 5_000), (-2_500, 0))
# Construct layered velocity model with water-sediment system.
model = sn.layered_meshing.LayeredModel(
[
# Upper layer: seawater with typical oceanic properties
# Density 1050 kg/m³, P-wave velocity 1500 m/s.
sn.material.from_params(rho=1050.0, vp=1500.0),
# Interface between water and sediment layers. Creates linear-ramp
# seafloor with 500m elevation variation.
sn.layered_meshing.interface.Curve.from_points(
[0.0, 5_000.0],
[-250.0, +250.0],
axis="x",
reference_elevation=-500.0,
),
# Lower layer: sediment with higher velocity and optional anomaly
# density 2000 kg/m³, base P-wave velocity 3000 m/s.
sn.material.from_params(
rho=2000.0,
# Velocity field: base velocity + optional perturbation.
vp=xr.DataArray(
3000 + (anom if add_low_vel else 0.0),
[("x", x), ("v", y)],
),
),
]
)
# Configure local refinement if policy is provided; Local refinement
# targets specific regions (e.g., around velocity anomalies).
if local_refinement_policy is not None:
# Extract oracle filter and refinement policy from user-provided
# function. Oracle filter masks the minimum velocity in the meshing
# process; local refinement policy specifies where to refine.
of, lrp = local_refinement_policy(domain=domain, layered_model=model)
else:
# Use default behavior: no local refinement
of, lrp = sn.layered_meshing.filters.no_filter, None
# Generate the final mesh using Salvus layered meshing system.
return sn.layered_meshing.mesh_from_domain(
domain=domain,
# Meshing protocol encapsulates all mesh generation parameters.
model=sn.layered_meshing.MeshingProtocol(
# The layered material model defined above.
lm=model,
# Control mesh transitions between different material layers.
interlayer_coarsening_policy=interlayer_coarsening_policy,
# Control mesh refinement within individual layers.
intralayer_coarsening_policy=intralayer_coarsening_policy,
# Apply local refinement in targeted regions.
local_refinement_policy=lrp,
# Oracle filter determines refinement locations.
oracle_filter=of,
),
mesh_resolution=sn.MeshResolution(
reference_frequency=10.0,
elements_per_wavelength=1.0,
),
    )

def compute_cost_factor(n_elm: int, time_step: float) -> float:
"""
Compute computational cost factor for mesh and time-step combination.
Estimates the relative computational cost by combining number of elements
with time-stepping constraints. Uses logarithmic scaling to handle the wide
range of values encountered in practice.
Args:
n_elm: Number of elements in the mesh.
time_step: Minimum stable time step for the mesh.
"""
    return np.log(n_elm / time_step)

def plot_and_report_cost(
mesh: sn.UnstructuredMesh,
n: int = 1,
show_gll: bool = True,
title: str | None = None,
):
"""
Visualize mesh with the time-step and report computational cost metrics.
Args:
mesh: The mesh to analyze and visualize.
n: Polynomial order for GLL point computation and time-step
calculation.
show_gll: If True, overlay red dots showing GLL point locations.
title: Title of the plot. If None, no title is added.
"""
max_vp = np.max(mesh.element_nodal_fields["VP"], axis=-1)
_, time_step = metrics.compute_time_step(
mesh=mesh,
simulation_order=n,
max_velocity=max_vp,
)
plot_mesh(mesh, time_step, title=title)
if show_gll:
plt.plot(*compute_gll_points(mesh, n), "ro", markersize=1)
print(f"NUM ELM: {mesh.nelem}")
print(f"MIN TIME STEP: {np.min(time_step)}")
print(f"COST FACTOR: {compute_cost_factor(mesh.nelem, np.min(time_step))}")nontrivial_mesh = generate_nontrivial_mesh()
plot_mesh(
nontrivial_mesh,
np.mean(nontrivial_mesh.element_nodal_fields["VP"], axis=-1),
title="$v_p$ [m/s]",
)

plot_and_report_cost(
nontrivial_mesh, n=1, show_gll=True, title="Time step [s]"
)
NUM ELM: 340
MIN TIME STEP: 0.014636505628749122
COST FACTOR: 10.0531821031102

plot_and_report_cost(
generate_nontrivial_mesh(
interlayer_coarsening_policy=sn.layered_meshing.InterlayerConstant()
),
title="Time step [s]",
)
NUM ELM: 442
MIN TIME STEP: 0.024566243134638042
COST FACTOR: 9.79769189079098

plot_and_report_cost(
generate_nontrivial_mesh(
intralayer_coarsening_policy=sn.layered_meshing.IntralayerVerticalRefine()
),
title="Time step [s]",
)
NUM ELM: 340
MIN TIME STEP: 0.011171871522389187
COST FACTOR: 10.323301748540702

plot_and_report_cost(
generate_nontrivial_mesh(
intralayer_coarsening_policy=[
sn.layered_meshing.IntralayerVerticalRefine(),
sn.layered_meshing.IntralayerConstant(),
]
),
n=1,
show_gll=True,
title="Time step [s]",
)
NUM ELM: 324
MIN TIME STEP: 0.011171871522389187
COST FACTOR: 10.275099646722825

plot_and_report_cost(
generate_nontrivial_mesh(
intralayer_coarsening_policy=[
sn.layered_meshing.IntralayerVerticalRefine(
refinement_type="doubling"
),
sn.layered_meshing.IntralayerConstant(),
]
),
n=1,
show_gll=True,
title="Time step [s]",
)
NUM ELM: 304
MIN TIME STEP: 0.014636505628749122
COST FACTOR: 9.941264186906217

mesh_anom = generate_nontrivial_mesh(
intralayer_coarsening_policy=[
sn.layered_meshing.IntralayerVerticalRefine(
refinement_type="doubling"
),
sn.layered_meshing.IntralayerConstant(),
],
local_refinement_policy=functools.partial(
simple_post_refinement,
fac=[2.0],
restrict=[1],
refinement_style="unstable",
),
add_low_vel=True,
)
plot_mesh(mesh_anom, np.mean(mesh_anom.element_nodal_fields["VP"], axis=-1))
plot_and_report_cost(mesh_anom, n=1, show_gll=False, title="Time step [s]")
NUM ELM: 950
MIN TIME STEP: 0.006302712208122705
COST FACTOR: 11.923237213595918