Imports

In [1]:
import json
import panel as pn
import numpy as np
import pandas as pd
import holoviews as hv
from holoviews import opts, streams
from holoviews.plotting.links import DataLink
import shapely 
import geoviews as gv
import hvplot.pandas
import pyposeidon.mesh as pmesh
import seareport_data
import matplotlib.pyplot as plt

# Load the Bokeh plotting backend used by the HoloViews/GeoViews/hvPlot figures below.
hv.extension("bokeh")

Functions

Utility functions

In [2]:
def extract_area(x, y, triangles, lon_range, lat_range):
    """
    Extract the nodes inside a lon/lat box and the triangles whose nodes are all inside it.

    Args:
        x, y: node coordinate arrays (same length).
        triangles: connectivity array, one row of node indices per triangle.
        lon_range: (min, max) bounds applied to ``x``, both inclusive.
        lat_range: (min, max) bounds applied to ``y``, both inclusive.

    Returns:
        extracted_x, extracted_y: coordinates of the nodes inside the box.
        extracted_triangles: connectivity of the kept triangles, re-indexed to local
            node indices (positions in ``extracted_x`` / ``extracted_y``).
        node_indices: global indices of the kept nodes (local i -> node_indices[i]).
        triangle_indices: global indices of the kept triangles.
        node_map_reverse: dict mapping local node index -> global node index.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    triangles = np.asarray(triangles)

    mask = (x >= lon_range[0]) & (x <= lon_range[1]) & \
           (y >= lat_range[0]) & (y <= lat_range[1])
    node_indices = np.where(mask)[0]
    node_map_reverse = dict(enumerate(node_indices))
    extracted_x = x[node_indices]
    extracted_y = y[node_indices]

    # Vectorised "all nodes of the triangle are inside the box" test (a Python loop over
    # millions of triangles is very slow). Entries that are not valid node indices
    # (e.g. negative fill values) never count as inside.
    valid = (triangles >= 0) & (triangles < len(x))
    node_inside = np.zeros(triangles.shape, dtype=bool)
    node_inside[valid] = mask[triangles[valid].astype(np.intp)]
    # flatnonzero always yields an integer array, so an empty selection stays indexable.
    triangle_indices = np.flatnonzero(node_inside.all(axis=1))

    # Global -> local node lookup table; only kept nodes are ever looked up.
    local_index = np.full(len(x), -1, dtype=np.intp)
    local_index[node_indices] = np.arange(node_indices.size)
    extracted_triangles = local_index[triangles[triangle_indices].astype(np.intp)]

    return extracted_x, extracted_y, extracted_triangles, node_indices, triangle_indices, node_map_reverse

def is_ccw(tris, meshx, meshy):
    """
    Return a boolean mask that is True where a triangle's vertices run counter-clockwise.

    The test compares the two products of the 2-D cross product (p2 - p1) x (p3 - p1).
    Collinear (zero-area) triangles therefore give False.

    Args:
        tris: connectivity, one row of three node indices per triangle.
        meshx, meshy: node coordinate arrays indexed by ``tris``.
    """
    corner_x = meshx[tris].T
    corner_y = meshy[tris].T
    x1, x2, x3 = corner_x
    y1, y2, y3 = corner_y
    cross_lhs = (y3 - y1) * (x2 - x1)
    cross_rhs = (y2 - y1) * (x3 - x1)
    return cross_lhs > cross_rhs

def is_overlapping(tris, meshx, threshold=180):
    """
    Flag triangles where any two vertices are more than ``threshold`` apart in x.

    With longitudes in degrees this presumably targets triangles that wrap across the
    +/-180 degree seam (NOTE(review): intent inferred from the 180 default; confirm).

    Args:
        tris: connectivity, one row of three node indices per triangle.
        meshx: node x coordinates indexed by ``tris``.
        threshold: maximum allowed |dx| between any two vertices (default: 180).

    Returns:
        Boolean array, True for triangles with an edge whose |dx| exceeds ``threshold``.
    """
    x1, x2, x3 = meshx[tris].T
    # All three edges must be checked. np.logical_or only takes two inputs (a third
    # positional argument is its `out` buffer), so chain with `|` instead.
    return (
        (np.abs(x2 - x1) > threshold)
        | (np.abs(x3 - x1) > threshold)
        | (np.abs(x3 - x2) > threshold)
    )

Panel dashboard function

In [3]:
def extract_and_plot_mesh_elements(mesh, element_indices, xbuffer=0.1, ybuffer=0.1):
    """
    Extract nodes and triangles around specific elements of the mesh and create an interactive dashboard for editing.

    The bounding box of the requested elements is enlarged by ``xbuffer`` / ``ybuffer``.
    Every node inside that box becomes a draggable point. Every triangle whose nodes are
    all inside the box is drawn as a TriMesh that follows the dragged points. GSHHG
    coastlines intersecting the box are overlaid when there are any.

    Args:
        mesh: The mesh object (e.g., from pmesh). Its ``Dataset`` must provide
            ``SCHISM_hgrid_node_x``, ``SCHISM_hgrid_node_y`` and ``SCHISM_hgrid_face_nodes``.
        element_indices: List of indices of the elements to focus on.
        xbuffer: Margin added on both sides of the elements' x extent (default: 0.1).
        ybuffer: Margin added on both sides of the elements' y extent (default: 0.1).

    Returns:
        A tuple containing:
        - The panel layout for interactive editing.
        - The extracted node indices (global indices of the nodes inside the box).
        - The extracted triangle indices (global indices of the triangles inside the box).
        - The node mapping dictionary (local node index -> global node index).
        - The extracted triangles connectivity, in local node indices.
        - The ``hv.Points`` element whose ``data`` receives the edits made in the dashboard.
    """
    # Extract mesh data
    x = mesh.Dataset.SCHISM_hgrid_node_x.values
    y = mesh.Dataset.SCHISM_hgrid_node_y.values
    tri3 = mesh.Dataset.SCHISM_hgrid_face_nodes.values

    # Get the bounding box around the elements
    element_nodes = tri3[element_indices].flatten()
    x_min, x_max = x[element_nodes].min() - xbuffer, x[element_nodes].max() + xbuffer
    y_min, y_max = y[element_nodes].min() - ybuffer, y[element_nodes].max() + ybuffer

    # Get coastlines close to the area of interest
    area_interest = shapely.box(x_min, y_min, x_max, y_max)
    coasts = seareport_data.gshhg_df('f', '6')
    coast_mask = coasts.intersects(area_interest)
    local_coast = coasts[coast_mask]
    llo, lla, tri_sub, node_indices, tri_indices, node_mapping = extract_area(x, y, tri3, (x_min, x_max), (y_min, y_max))

    points = hv.Points((llo, lla, node_indices), vdims=['original_index'])

    point_stream = streams.PointDraw(data=points.columns(), source=points, num_objects=len(llo), empty_value='black')

    def update_trimesh(data):
        # Redraw the local TriMesh from the current (possibly dragged) point positions.
        return hv.TriMesh((tri_sub, data))

    trimesh_dmap = hv.DynamicMap(update_trimesh, streams=[point_stream])
    table = hv.Table(points)

    DataLink(points, table)

    def update_data(event):
        # param watchers are called with a param Event (never a DataFrame), so always
        # copy the current stream contents into `points` (read back by the caller via
        # the returned element) and into the table.
        updated_data = pd.DataFrame(point_stream.data)
        points.data = updated_data
        table.data = updated_data

    point_stream.param.watch(update_data, 'data')

    # Overlay order: coastlines (if any) at the bottom, then the TriMesh, then the points.
    base = trimesh_dmap
    if coast_mask.any():
        coastlines = local_coast.hvplot().options(xlim=(x_min, x_max), ylim=(y_min, y_max)).opts(alpha=0.6)
        base = coastlines * trimesh_dmap

    # Combine the plot and table
    layout = ((base * points).opts(width=1200, height=800) + table).opts(
        opts.Layout(merge_tools=False),
        opts.Points(active_tools=['point_draw'], alpha=0.6, color='red', size=20, tools=['hover']),
        opts.TriMesh(edge_color='black'),
        opts.Table(editable=True)
    )

    # Return the layout and extracted data
    return pn.panel(layout), node_indices, tri_indices, node_mapping, tri_sub, points

Function for updating the mesh based on the panel edits

In [4]:
def update_mesh_with_interactions(mesh, points_data, tri_sub, tri_indices, node_mapping):
    """
    Update the mesh with the interactions/edits done in the panel dashboard.

    The edited node coordinates are written back into the mesh (in place). Then every
    extracted triangle that is no longer counter-clockwise has its vertex order reversed.

    Args:
        mesh: The mesh object (e.g., from pmesh).
        points_data: The updated points data from the dashboard. The first two columns
            are the node coordinates, plus an ``original_index`` column.
        tri_sub: extracted triangles, in local (row-position) node indices.
        tri_indices: The original indices of the extracted triangles.
        node_mapping: The node mapping dictionary (local node index -> global node index).
    Returns:
        A tuple containing:
        - The updated mesh object
        - the indices of the flipped triangles
        - the original nodes
        - the modified nodes
    Raises:
        ValueError: if the rows of ``points_data`` no longer correspond, in order, to the
            extracted nodes (e.g. points were added or deleted in the dashboard).
            ``tri_sub`` addresses those rows by position, so applying them would corrupt the mesh.

    NOTE(review): only the extracted triangles (all nodes inside the box) are checked.
    A triangle that shares a moved node with nodes outside the box is not re-oriented.
    """
    # Function to flip triangles with incorrect orientation
    def flip_triangles(triangles, is_ccw_mask):
        flipped_triangles = triangles.copy()
        flipped_triangles[~is_ccw_mask] = np.flip(flipped_triangles[~is_ccw_mask], axis=1)
        return flipped_triangles

    tri_sub = np.asarray(tri_sub)
    tri_indices = np.asarray(tri_indices)

    # Extract mesh data
    x = mesh.Dataset.SCHISM_hgrid_node_x
    y = mesh.Dataset.SCHISM_hgrid_node_y
    tri3 = mesh.Dataset.SCHISM_hgrid_face_nodes

    # Get modified node coordinates
    modified_nodes = points_data.iloc[:, :2].values
    # Cast defensively: positional indexing requires integers, and the column may have
    # gone through an editable table/stream round-trip.
    original_indices = points_data['original_index'].values.astype(int)

    # Local index i -> global node index; this is also the expected row order of points_data.
    local_to_global = np.array([node_mapping[i] for i in range(len(node_mapping))], dtype=int)
    if not np.array_equal(original_indices, local_to_global):
        raise ValueError(
            "points_data does not match the extracted nodes (points added, removed or "
            "reordered?); cannot map the edits back onto the mesh."
        )

    original_nodes = np.vstack((x[original_indices], y[original_indices])).T

    # Update node coordinates in the mesh
    x[original_indices] = modified_nodes[:, 0]
    y[original_indices] = modified_nodes[:, 1]

    # Check triangle orientation and flip if necessary
    is_ccw_mask = is_ccw(tri_sub, modified_nodes[:, 0], modified_nodes[:, 1])
    tri_sub_flipped_local = flip_triangles(tri_sub, is_ccw_mask)
    # Vectorised local -> global lookup; keeps a (k, 3) shape even when k == 0.
    tri_sub_flipped_global = local_to_global[tri_sub_flipped_local]
    flipped_triangle_indices = tri_indices[~is_ccw_mask]

    # Update triangles in the mesh
    tri3[flipped_triangle_indices] = tri_sub_flipped_global[~is_ccw_mask]

    # Save the updated mesh
    mesh.Dataset["SCHISM_hgrid_node_x"] = x
    mesh.Dataset["SCHISM_hgrid_node_y"] = y
    mesh.Dataset["SCHISM_hgrid_face_nodes"] = tri3
    return mesh, flipped_triangle_indices, original_nodes, modified_nodes

JSON functions

In [5]:
# Record transformations in a JSON file
def save_transformations(original_nodes, modified_nodes, node_indices, flipped_triangle_indices, filename):
    """
    Record node moves and triangle flips in a JSON file (readable by ``apply_transformations``).

    Row i of ``original_nodes`` / ``modified_nodes`` belongs to global node ``node_indices[i]``.
    A node is recorded when its x OR its y coordinate changed.

    Args:
        original_nodes: (N, 2) coordinates before editing.
        modified_nodes: (N, 2) coordinates after editing.
        node_indices: global node index for each row.
        flipped_triangle_indices: global indices of the triangles whose orientation was flipped.
        filename: path of the JSON file to write (overwritten).
    """
    original_nodes = np.asarray(original_nodes)
    modified_nodes = np.asarray(modified_nodes)
    transformations = {
        "node_transformations": [
            {"original_index": int(node_indices[i]), "original_coords": [float(original_nodes[i, 0]), float(original_nodes[i, 1])],
             "modified_coords": [float(modified_nodes[i, 0]), float(modified_nodes[i, 1])]}
            for i in range(len(node_indices))
            # `or`: a purely horizontal or purely vertical move is still a move.
            if (modified_nodes[i, 0] != original_nodes[i, 0]) or (modified_nodes[i, 1] != original_nodes[i, 1])
        ],
        "flipped_triangles": [int(idx) for idx in flipped_triangle_indices]
    }
    with open(filename, 'w') as f:
        json.dump(transformations, f, indent=4)

# reproduce transformation
def apply_transformations(mesh, transformations_file):
    """
    Apply transformations (node movements and triangle flips) to a mesh based on a JSON file.

    The file format is the one written by ``save_transformations``. Node coordinates are
    updated in place. Each listed triangle has its vertex order reversed.

    Args:
        mesh: The mesh object (e.g., from pmesh).
        transformations_file (str): Path to the JSON file containing transformations.

    Returns:
        The modified mesh object (the same object that was passed in).
    """
    # Load transformations from the JSON file
    with open(transformations_file, 'r') as f:
        transformations = json.load(f)

    # Apply node transformations
    for transform in transformations["node_transformations"]:
        original_index = transform["original_index"]
        modified_coords = transform["modified_coords"]
        # Update node coordinates in the mesh
        mesh.Dataset.SCHISM_hgrid_node_x[original_index] = modified_coords[0]
        mesh.Dataset.SCHISM_hgrid_node_y[original_index] = modified_coords[1]

    # Apply triangle flips
    flipped_triangles = transformations["flipped_triangles"]
    if flipped_triangles:
        face_nodes = mesh.Dataset["SCHISM_hgrid_face_nodes"]
        tri3 = np.array(face_nodes.values, copy=True)
        for idx in flipped_triangles:
            tri3[idx] = np.flip(tri3[idx])
        # Dict-style assignment: xarray rejects attribute assignment
        # (`ds.SCHISM_hgrid_face_nodes = ...`) with an AttributeError.
        # copy(data=...) keeps the variable's dims/attrs.
        mesh.Dataset["SCHISM_hgrid_face_nodes"] = face_nodes.copy(data=tri3)

    return mesh

Edit Mesh

Context

We have a mesh generated with oceanmesh that has the following attributes:

uniform.gr3
15312900 7990779

When launching the SCHISM preprocessing, we have a problem with the following element:

4223682
In [8]:
mesh = pmesh.set(
    mesh_file="v3.2/GSHHS_f_0.01.gr3",
    type="tri2d",
)
2025-02-13 12:27:05,674 INFO     pyposeidon.mesh read_file:239 read mesh file v3.2/GSHHS_f_0.01.gr3
/home/tomsail/miniconda3/envs/pos_test/lib/python3.11/site-packages/pyposeidon/mesh.py:247: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`
  ni, nj = df.iloc[0].str.split()[0]

Load the dashboard and edit the mesh

In [9]:
# Element(s) reported as problematic by the SCHISM preprocessing.
element_indices = [4223682]
layout, node_indices, tri_indices, node_mapping, tri_sub, points = extract_and_plot_mesh_elements(
    mesh,
    element_indices,
    xbuffer=0.5,
    ybuffer=0.1,
)
layout.servable()
BokehModel(combine_events=True, render_bundle={'docs_json': {'94f388f5-7e17-46af-aa92-0595ae93299b': {'version…
Out[9]:

Update the mesh with the interactions entered in the dashboard

In [41]:
# NOTE: this edits `mesh` in place. Running the cell a second time makes `original_nodes`
# equal to the already-edited coordinates, so the transformations saved below would be empty.
(
    mesh,
    flipped_triangle_indices,
    original_nodes,
    modified_nodes,
) = update_mesh_with_interactions(mesh, points.data, tri_sub, tri_indices, node_mapping)

Save the transformations in a JSON file

In [42]:
edited_original_indices = points.data['original_index'].values
save_transformations(
    original_nodes,
    modified_nodes,
    edited_original_indices,
    flipped_triangle_indices,
    "transformations.json",
)

Other useful functions

In [ ]:
def find_hanging_nodes(nodes, connectivity, min_count=2):
    """
    Identify hanging nodes in the mesh.

    A node is reported when fewer than ``min_count`` elements reference it. This
    includes nodes that no element references at all. Negative entries in
    ``connectivity`` (fill values) are ignored.

    :param nodes: List of nodes (coordinates); only its length is used.
    :param connectivity: List of triangles (indices of nodes).
    :param min_count: Minimum number of element references for a node not to be
        considered hanging (default: 2).
    :return: Set of hanging node indices.
    """
    all_nodes = np.asarray(connectivity, dtype=np.int64).ravel()
    all_nodes = all_nodes[all_nodes >= 0]
    # minlength makes every node index appear in the counts, so unreferenced
    # nodes get a count of 0 instead of being silently skipped.
    counts = np.bincount(all_nodes, minlength=len(nodes))
    return {int(i) for i in np.flatnonzero(counts < min_count)}

def suppress_hanging_nodes(nodes, connectivity, hanging_nodes):
    """
    Remove hanging nodes and renumber the remaining nodes and the connectivity.

    Surviving nodes keep their relative order and are renumbered 0..M-1. A triangle is
    kept only if, after dropping hanging nodes, exactly three node indices remain.

    :param nodes: List of nodes (coordinates).
    :param connectivity: List of triangles (indices of nodes).
    :param hanging_nodes: Set of hanging node indices.
    :return: New nodes, new connectivity, and a mapping from old to new indices.
    """
    survivors = [(old, node) for old, node in enumerate(nodes) if old not in hanging_nodes]
    old_to_new = {old: new for new, (old, _) in enumerate(survivors)}
    new_nodes = [node for _, node in survivors]

    new_connectivity = []
    for tri in connectivity:
        renumbered = [old_to_new[vertex] for vertex in tri if vertex not in hanging_nodes]
        if len(renumbered) == 3:  # still a complete triangle
            new_connectivity.append(renumbered)

    return new_nodes, new_connectivity, old_to_new