diff --git a/.gitignore b/.gitignore index 6af1c3a..33d02ba 100755 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ __pycache__/ datasets/**/*.pkl envs/gym-pybullet-drones/ results/ -stats/ \ No newline at end of file +stats/ +venv/ +.DS_Store \ No newline at end of file diff --git a/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl new file mode 100644 index 0000000..c0de956 Binary files /dev/null and b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl differ diff --git a/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_SO3.png b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_SO3.png new file mode 100644 index 0000000..36eef29 Binary files /dev/null and b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_SO3.png differ diff --git a/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_principal.png b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_principal.png new file mode 100644 index 0000000..c7a69b4 Binary files /dev/null and b/data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100_phase_space_principal.png differ diff --git a/datasets/3d_tennis_racket_dataset_plot.py b/datasets/3d_tennis_racket_dataset_plot.py new file mode 100644 index 0000000..7c83068 --- /dev/null +++ b/datasets/3d_tennis_racket_dataset_plot.py @@ -0,0 +1,460 @@ +""" +Tennis Racket Phase Space Visualization + +Generates TWO output figures per dataset: + + Figure 1 — Principal-axis angular velocity phase space + 6 rows × num_configs cols + Per-axis (Euler angle, body angular velocity) projections. + Axes correspond to body-frame principal axes e1, e2, e3. 
+ + Figure 2 — SO(3)-aware phase space + 8 rows × num_configs cols + For each of {Train, Test}: + - e1 principal axis trajectory on S² (3D, orientation of long head axis) + - Intermediate-axis alignment α = arccos(|ω̂ · e₂|) vs dα/dt (2D phase portrait) + - ||ω|| over time + - Rotational kinetic energy T(t) = ½ ωᵀ I ω (no gravity — free body) +""" + +import pickle +import numpy as np +import os +import matplotlib.pyplot as plt +from matplotlib.collections import LineCollection +from matplotlib.colors import Normalize +from mpl_toolkits.mplot3d.art3d import Line3DCollection +from mpl_toolkits.mplot3d import Axes3D # noqa: F401 (registers 3d projection) + + +# ─────────────────── I/O ─────────────────── + +def load_data(path): + with open(path, 'rb') as f: + return pickle.load(f) + + +# ─────────────────── Geometry helpers ─────────────────── + +def rotmat_to_euler(R_flat): + """Vectorized batch rotmat → Euler ZYX (roll, pitch, yaw). + R_flat: (..., 9) → (..., 3) in radians. + NOTE: ZYX has gimbal lock at pitch = ±π/2; visualization artifact only. + """ + Rs = R_flat.reshape(*R_flat.shape[:-1], 3, 3) + R00 = Rs[..., 0, 0] + R10 = Rs[..., 1, 0] + R20 = Rs[..., 2, 0] + R21 = Rs[..., 2, 1] + R22 = Rs[..., 2, 2] + R12 = Rs[..., 1, 2] + R11 = Rs[..., 1, 1] + + sy = np.sqrt(R00**2 + R10**2) + near_singular = sy < 1e-6 + + roll = np.where(near_singular, + np.arctan2(-R12, R11), + np.arctan2(R21, R22)) + pitch = np.arctan2(-R20, sy) + yaw = np.where(near_singular, + np.zeros_like(R00), + np.arctan2(R10, R00)) + return np.stack([roll, pitch, yaw], axis=-1) + + +def _rotmat_principal_axis(R_flat, axis=0): + """Extract a principal body-frame axis direction R @ e_i from flat rotation matrices. + R_flat: (..., 9) → (..., 3). + Column i of R (row-major flat) = R_flat[..., [i*3+0, i*3+1, i*3+2]] + but since R is stored row-major: column i = R_flat[..., [i, i+3, i+6]]. 
+ """ + return R_flat[..., [axis, axis + 3, axis + 6]] + + +def _resolve_dataset_path(save_dir, filename): + """Return (filename, file_path), or (None, None) if not found.""" + if filename is None: + files = sorted([f for f in os.listdir(save_dir) if f.endswith('.pkl')]) + if not files: + print(f"No .pkl files in {save_dir}") + return None, None + filename = files[0] + print(f"Auto-selected: {filename}") + file_path = os.path.join(save_dir, filename) + if not os.path.exists(file_path): + print(f"File not found: {file_path}") + return None, None + return filename, file_path + + +# ─────────────────── Figure 1: Principal-axis phase space ─────────────────── + +def plot_racket_phase_space(save_dir, num_trajs_to_plot=15, filename=None): + """Per-principal-axis (Euler angle, body ω_i) phase portrait. + + Mirrors the original pendulum plot but: + - axis labels are renamed to e1 (stable/long), e2 (unstable/intermediate), e3 (stable/handle) + - colormap is keyed on ||disturbance torque|| instead of ||wind force|| + - column titles show 'Config N' (racket geometry) instead of 'Batch N' + - figure title is updated for the tennis racket context + """ + filename, file_path = _resolve_dataset_path(save_dir, filename) + if file_path is None: + return + + print(f"Loading: {file_path}") + data = load_data(file_path) + + train_data = data['x'] # (num_configs, T, N_train, 15) + test_data = data.get('test_x', None) # (num_configs, T, N_test, 15) + + num_configs = train_data.shape[0] + # CHANGED: axis labels reflect principal body-frame axes of the racket, + # not generic X/Y/Z pendulum Euler angles. + axis_labels = ['e1 (stable, long)', 'e2 (unstable)', 'e3 (stable, handle)'] + + # Global disturbance-torque range for shared colormap normalization. + # CHANGED: the "action" stored in dims 12:15 is the constant disturbance + # torque (Nm), not a wind force, so label and semantics are updated. 
+ all_u = [np.linalg.norm(train_data[..., 12:15], axis=-1).flatten()] + if test_data is not None: + all_u.append(np.linalg.norm(test_data[..., 12:15], axis=-1).flatten()) + u_max = np.concatenate(all_u).max() + norm = Normalize(vmin=0, vmax=max(u_max, 0.01)) + cmap = plt.get_cmap('viridis') + + nrows = 6 + ncols = num_configs # CHANGED: variable renamed from num_us → num_configs + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, + figsize=(4.5 * ncols, 3 * nrows), squeeze=False) + + datasets = [ + ("Train", train_data, 0), + ("Test", test_data, 3), + ] + + for ds_label, ds_data, row_offset in datasets: + if ds_data is None: + continue + N_avail = ds_data.shape[2] + n_plot = min(N_avail, num_trajs_to_plot) + + for cfg_idx in range(num_configs): # CHANGED: u_idx → cfg_idx + batch = ds_data[cfg_idx] # (T, N, 15) + R_all = batch[..., :9] # (T, N, 9) + eulers_all = rotmat_to_euler(R_all) # (T, N, 3) + + for axis_idx in range(3): + ax = axes[row_offset + axis_idx, cfg_idx] + + for trial in range(n_plot): + angle = np.unwrap(eulers_all[:, trial, axis_idx]) + omega = batch[:, trial, 9 + axis_idx] + u_norms = np.linalg.norm(batch[:, trial, 12:15], axis=1) + + points = np.array([angle, omega]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + lc = LineCollection(segments, cmap=cmap, norm=norm, + alpha=0.5, linewidth=1.0) + lc.set_array(u_norms[:-1]) + ax.add_collection(lc) + ax.scatter(angle[0], omega[0], color='black', s=8, zorder=3) + + ax.autoscale() + ax.grid(True, alpha=0.3, linestyle='--') + + if cfg_idx == 0: + ax.set_ylabel(f"{ds_label} — {axis_labels[axis_idx]}\n" + + r"$\omega$ (rad/s)", fontsize=9) + + if row_offset + axis_idx == 0: + # CHANGED: 'Batch N' → 'Config N' (each column = racket geometry) + ax.set_title(f"Config {cfg_idx}", fontsize=11, fontweight='bold') + + if row_offset + axis_idx == nrows - 1: + ax.set_xlabel("Angle (rad)", fontsize=9) + else: + ax.set_xticklabels([]) + + # Colorbar + # CHANGED: label updated 
from wind force to disturbance torque + cbar_ax = fig.add_axes([0.93, 0.12, 0.015, 0.76]) + cb = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax) + cb.set_label(r'$\|\tau_d\|$ (Nm)', fontsize=10) + + # CHANGED: title updated for the tennis racket free rigid body + fig.suptitle("Tennis Racket Phase Space — Rotation about e1, e2, e3 (Euler ZYX)", + fontsize=14, fontweight='bold', y=0.98) + plt.subplots_adjust(right=0.91, hspace=0.25, wspace=0.3) + + ds_name = os.path.splitext(filename)[0] + # CHANGED: output filename suffix from _phase_space_XYZ → _phase_space_principal + out_path = os.path.join(save_dir, f'{ds_name}_phase_space_principal.png') + fig.savefig(out_path, dpi=150, bbox_inches='tight') + print(f"Saved: {out_path}") + plt.close(fig) + + +# ─────────────────── Figure 2: SO(3)-aware phase space ─────────────────── + +def plot_so3_phase_space(save_dir, num_trajs_to_plot=15, filename=None): + """Geometrically-natural visualizations for SO(3) tennis-racket data. + + Per dataset (train, test) and per geometry config: + Row 0/4: e1 principal axis trajectory on S² (orientation of long head axis in world frame) + Row 1/5: Intermediate-axis alignment α = arccos(|ω̂ · e₂|) vs dα/dt — Dzhanibekov portrait + Row 2/6: ||ω|| (rad/s) over time + Row 3/7: Rotational kinetic energy T(t) = ½ ωᵀ I ω (no gravity — free rigid body) + + Key differences from the pendulum SO(3) plot: + - No gravity → no potential energy term; Row 3/7 shows kinetic energy only. + - Bob-on-S² replaced by e1-axis-on-S²: tracks where the long head axis points. + - Tilt angle replaced by intermediate-axis alignment (Dzhanibekov instability). + - Colormap keyed on ||disturbance torque|| not ||wind force||. + - Inertia values (I1, I2, I3) per config are shown as subtitle text. + - Column title is 'Config N' not 'Batch N'. + - m, l, g parameters are NOT needed (free rigid body, no pendulum physics). + - I_diag is read from data['inertia_info'] stored by the datagen. 
+ """ + filename, file_path = _resolve_dataset_path(save_dir, filename) + if file_path is None: + return + + print(f"Loading: {file_path}") + data = load_data(file_path) + + t = np.asarray(data['t']) # (T,) + # CHANGED: read inertia from the stored inertia_info list (one entry per config). + # The pendulum used m, l, g kwargs; racket has no gravity and needs I_diag per config. + inertia_info_list = data.get('inertia_info', []) + + train_data = data['x'] + test_data = data.get('test_x', None) + num_configs = train_data.shape[0] # CHANGED: num_us → num_configs + + # Global disturbance-torque range + all_u = [np.linalg.norm(train_data[..., 12:15], axis=-1).flatten()] + if test_data is not None: + all_u.append(np.linalg.norm(test_data[..., 12:15], axis=-1).flatten()) + u_max = np.concatenate(all_u).max() + norm = Normalize(vmin=0, vmax=max(u_max, 0.01)) + cmap = plt.get_cmap('viridis') + + nrows = 8 + ncols = num_configs + fig = plt.figure(figsize=(4.5 * ncols, 3.5 * nrows)) + + # Build the axes grid manually because rows 0 and 4 need 3D projection + axes = [[None] * ncols for _ in range(nrows)] + for r in range(nrows): + for c in range(ncols): + idx = r * ncols + c + 1 + if r in (0, 4): + axes[r][c] = fig.add_subplot(nrows, ncols, idx, projection='3d') + else: + axes[r][c] = fig.add_subplot(nrows, ncols, idx) + + datasets = [ + ("Train", train_data, 0), + ("Test", test_data, 4), + ] + + # Wireframe sphere coordinates (reused per panel) + uu, vv = np.mgrid[0:2 * np.pi:24j, 0:np.pi:12j] + sx = np.cos(uu) * np.sin(vv) + sy_ = np.sin(uu) * np.sin(vv) + sz = np.cos(vv) + + for ds_label, ds_data, row_offset in datasets: + if ds_data is None: + continue + N_avail = ds_data.shape[2] + n_plot = min(N_avail, num_trajs_to_plot) + + for cfg_idx in range(num_configs): # CHANGED: u_idx → cfg_idx + batch = ds_data[cfg_idx] # (T, N, 15) + R_all = batch[..., :9] # (T, N, 9) + omega_all = batch[..., 9:12] # (T, N, 3) + u_all = batch[..., 12:15] # (T, N, 3) + + # CHANGED: extract e1 
(long head axis, column 0 of R) instead of bob direction (column 2). + # The pendulum tracked R @ e_z (bob direction = 3rd column). + # For the racket we track R @ e1 (long axis = 1st column) to visualize orientation. + e1_all = _rotmat_principal_axis(R_all, axis=0) # (T, N, 3) + + omega_norm_all = np.linalg.norm(omega_all, axis=-1) # (T, N) + u_norm_all = np.linalg.norm(u_all, axis=-1) # (T, N) + + # CHANGED: intermediate-axis alignment angle instead of tilt angle. + # Pendulum used tilt α = arccos((R e_z)_z), a geometric angle of the pendulum. + # For the racket we use α = arccos(|ω̂ · e₂|) to measure how aligned the spin + # is with the unstable intermediate axis — the heart of the Dzhanibekov effect. + e2_body = np.array([0.0, 1.0, 0.0]) # body-frame intermediate axis + omega_hat = omega_all / (omega_norm_all[..., np.newaxis] + 1e-12) + align_all = np.abs(omega_hat @ e2_body) # (T, N) + align_angle_all = np.arccos(np.clip(align_all, 0.0, 1.0)) # α ∈ [0, π/2] + + # CHANGED: kinetic energy only — no potential energy (free rigid body, no gravity). + # Pendulum used E = T + V with V = m g l (R e_z)_z. + # Racket has no gravity, so T(t) = ½ ωᵀ I ω is the full mechanical energy. + if cfg_idx < len(inertia_info_list): + info = inertia_info_list[cfg_idx] + I_diag = np.array([info['I1'], info['I2'], info['I3']]) + else: + I_diag = np.ones(3) # fallback + # T = ½ (I1 ω1² + I2 ω2² + I3 ω3²) + T_kin_all = 0.5 * np.sum(I_diag * omega_all**2, axis=-1) # (T, N) + + dt = float(t[1] - t[0]) if len(t) > 1 else 1.0 + align_rate_all = np.gradient(align_angle_all, dt, axis=0) # (T, N) + + # ── Row 0/4: e1 axis on S² ── + # CHANGED: visualizes where the racket's long head axis (e1) points in world frame, + # instead of the pendulum bob position. Reference markers are removed (no "up"/"down" + # physical meaning for a free body) — replaced by a neutral equator ring. 
+ ax_s2 = axes[row_offset + 0][cfg_idx] + ax_s2.plot_wireframe(sx, sy_, sz, color='lightgray', + alpha=0.3, linewidth=0.5) + # CHANGED: equator circle instead of north/south pole markers (no gravity reference) + theta_eq = np.linspace(0, 2 * np.pi, 100) + ax_s2.plot(np.cos(theta_eq), np.sin(theta_eq), np.zeros(100), + color='steelblue', linewidth=0.8, alpha=0.5) + + for trial in range(n_plot): + e1 = e1_all[:, trial, :] + u_n = u_norm_all[:, trial] + pts = e1.reshape(-1, 1, 3) + segs = np.concatenate([pts[:-1], pts[1:]], axis=1) + lc3d = Line3DCollection(segs, cmap=cmap, norm=norm, + alpha=0.6, linewidth=1.2) + lc3d.set_array(u_n[:-1]) + ax_s2.add_collection3d(lc3d) + ax_s2.scatter(*e1[0], color='black', s=8, zorder=11) + + ax_s2.set_xlim(-1.1, 1.1) + ax_s2.set_ylim(-1.1, 1.1) + ax_s2.set_zlim(-1.1, 1.1) + ax_s2.set_box_aspect([1, 1, 1]) + ax_s2.set_xlabel('x', fontsize=8) + ax_s2.set_ylabel('y', fontsize=8) + ax_s2.set_zlabel('z', fontsize=8) + ax_s2.tick_params(labelsize=7) + if row_offset == 0: + # CHANGED: column title shows 'Config N' and inertia values + if cfg_idx < len(inertia_info_list): + info = inertia_info_list[cfg_idx] + title_str = (f"Config {cfg_idx}\n" + f"I1={info['I1']:.4f} I2={info['I2']:.4f} I3={info['I3']:.4f}") + else: + title_str = f"Config {cfg_idx}" + ax_s2.set_title(title_str, fontsize=9, fontweight='bold') + if cfg_idx == 0: + # CHANGED: row label updated from 'Bob on S²' to 'e1 axis on S²' + ax_s2.text2D(-0.18, 0.5, f"{ds_label}\ne1 axis on S²", + transform=ax_s2.transAxes, fontsize=10, + rotation=90, va='center', ha='center', fontweight='bold') + + # ── Row 1/5: Intermediate-axis alignment phase portrait ── + # CHANGED: x-axis is now α = arccos(|ω̂ · e₂|), the angle between ω + # and the unstable axis e2. Pendulum used tilt α = arccos((R e_z)_z). + # This directly visualizes the Dzhanibekov flip dynamics. 
+ ax_align = axes[row_offset + 1][cfg_idx] + for trial in range(n_plot): + a = align_angle_all[:, trial] + ad = align_rate_all[:, trial] + u_n = u_norm_all[:, trial] + pts = np.array([a, ad]).T.reshape(-1, 1, 2) + segs = np.concatenate([pts[:-1], pts[1:]], axis=1) + lc = LineCollection(segs, cmap=cmap, norm=norm, + alpha=0.6, linewidth=1.0) + lc.set_array(u_n[:-1]) + ax_align.add_collection(lc) + ax_align.scatter(a[0], ad[0], color='black', s=8, zorder=3) + ax_align.autoscale() + ax_align.grid(True, alpha=0.3, linestyle='--') + ax_align.axhline(0, color='gray', linewidth=0.5, alpha=0.5) + # CHANGED: axis label reflects intermediate-axis alignment, not tilt + ax_align.set_xlabel(r"$\alpha = \arccos(|\hat{\omega}\cdot e_2|)$ (rad)", fontsize=9) + if cfg_idx == 0: + ax_align.set_ylabel(f"{ds_label}\n" + r"$\dot{\alpha}$ (rad/s)", + fontsize=9) + + # ── Row 2/6: ||ω|| over time ── + # (unchanged in structure; only label/variable name updates) + ax_om = axes[row_offset + 2][cfg_idx] + for trial in range(n_plot): + om_n = omega_norm_all[:, trial] + u_n = u_norm_all[:, trial] + pts = np.array([t, om_n]).T.reshape(-1, 1, 2) + segs = np.concatenate([pts[:-1], pts[1:]], axis=1) + lc = LineCollection(segs, cmap=cmap, norm=norm, + alpha=0.6, linewidth=1.0) + lc.set_array(u_n[:-1]) + ax_om.add_collection(lc) + ax_om.autoscale() + ax_om.grid(True, alpha=0.3, linestyle='--') + ax_om.set_xlabel("t (s)", fontsize=9) + if cfg_idx == 0: + ax_om.set_ylabel(f"{ds_label}\n" + r"$\|\omega\|$ (rad/s)", + fontsize=9) + + # ── Row 3/7: Kinetic energy over time ── + # CHANGED: pendulum plotted total energy E = T + V (with gravity). + # Racket is a free body: no gravity, so only kinetic energy T = ½ ωᵀ I ω. + # For a torque-free trajectory this should be nearly constant (numerical check). 
+ ax_T = axes[row_offset + 3][cfg_idx] + for trial in range(n_plot): + T_k = T_kin_all[:, trial] + u_n = u_norm_all[:, trial] + pts = np.array([t, T_k]).T.reshape(-1, 1, 2) + segs = np.concatenate([pts[:-1], pts[1:]], axis=1) + lc = LineCollection(segs, cmap=cmap, norm=norm, + alpha=0.6, linewidth=1.0) + lc.set_array(u_n[:-1]) + ax_T.add_collection(lc) + ax_T.autoscale() + ax_T.grid(True, alpha=0.3, linestyle='--') + ax_T.set_xlabel("t (s)", fontsize=9) + if cfg_idx == 0: + # CHANGED: label updated from total energy E (J) to kinetic energy T (J) + ax_T.set_ylabel(f"{ds_label}\n" + r"$T = \frac{1}{2}\omega^\top I\omega$ (J)", + fontsize=9) + + # Colorbar + # CHANGED: label updated from wind force to disturbance torque + cbar_ax = fig.add_axes([0.93, 0.12, 0.015, 0.76]) + cb = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax) + cb.set_label(r'$\|\tau_d\|$ (Nm)', fontsize=10) + + # CHANGED: figure title updated for the tennis racket / Dzhanibekov context + fig.suptitle("Tennis Racket SO(3)-Aware Phase Space " + r"(e2 = unstable intermediate axis, Dzhanibekov effect)", + fontsize=14, fontweight='bold', y=0.99) + plt.subplots_adjust(right=0.91, hspace=0.4, wspace=0.3) + + ds_name = os.path.splitext(filename)[0] + # CHANGED: output filename suffix retained as _phase_space_SO3 + out_path = os.path.join(save_dir, f'{ds_name}_phase_space_SO3.png') + fig.savefig(out_path, dpi=150, bbox_inches='tight') + print(f"Saved: {out_path}") + plt.close(fig) + + +# ─────────────────── Entry point ─────────────────── + +if __name__ == "__main__": + import argparse + + # CHANGED: description updated for tennis racket + parser = argparse.ArgumentParser(description="Plot tennis racket phase space from a dataset .pkl file.") + parser.add_argument("pkl_path", type=str, help="Path to the dataset .pkl file.") + parser.add_argument("--num_trajs", type=int, default=15, help="Number of trajectories to plot per panel.") + args = parser.parse_args() + + dataset_directory = 
os.path.dirname(os.path.abspath(args.pkl_path)) + filename = os.path.basename(args.pkl_path) + + # CHANGED: function names updated; m/l/g kwargs removed (not needed for free rigid body) + plot_racket_phase_space(dataset_directory, num_trajs_to_plot=args.num_trajs, filename=filename) + plot_so3_phase_space(dataset_directory, num_trajs_to_plot=args.num_trajs, filename=filename) diff --git a/datasets/tennis_racket_3d_datagen.py b/datasets/tennis_racket_3d_datagen.py new file mode 100644 index 0000000..a5e8ab7 --- /dev/null +++ b/datasets/tennis_racket_3d_datagen.py @@ -0,0 +1,477 @@ +import numpy as np +import pickle +import os +import argparse + +import sys +from pathlib import Path +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + +from envs.tennis_racket_3d import tennis_racket_3d + + +# ─────────────────── Helper Functions ─────────────────── +# (identical to windy_pendulum_3d_datagen) + +def to_pickle(obj, path): + with open(path, "wb") as f: + pickle.dump(obj, f) + print(f"Saved data to {path}") + + +def from_pickle(path): + with open(path, "rb") as f: + data = pickle.load(f) + print(f"Loaded data from {path}") + return data + + +def arrange_data(x, t, num_points=2): + """Arrange data to feed into neural ODE in small chunks. 
def add_proper_noise_3d(clean_data, obs_noise_std, rng):
    """
    Apply geometrically-correct observation noise to SO(3) rigid-body data.

    Identical to windy_pendulum_3d_datagen.add_proper_noise_3d:
        R_noisy = R @ expm(hat(eps)),   eps ~ N(0, sigma²)
        ω_noisy = ω + noise
        action is NOT corrupted.

    The exponential map is evaluated in closed form (Rodrigues' formula) for
    all samples at once instead of looping over matrices in Python.

    Parameters
    ----------
    clean_data : ndarray, shape (..., D)
        D = 9 (row-major R) + 3 (ω) + any trailing action dims (15 in the
        datasets built by this module).
    obs_noise_std : float
        Standard deviation of the rotation-vector noise (rad) and of the
        additive ω noise.
    rng : numpy.random.Generator
        Source of randomness.  Two draws are made, in this order: the
        rotation vectors ``eps``, then the ω noise.

    Returns
    -------
    ndarray with the same shape and dtype as ``clean_data``; the input is
    not modified.

    Raises
    ------
    ValueError
        If the last dimension of ``clean_data`` is smaller than 12.
    """
    clean_data = np.asarray(clean_data)
    if clean_data.ndim < 1 or clean_data.shape[-1] < 12:
        raise ValueError(
            "clean_data must have a trailing dimension >= 12 "
            f"(9 for R + 3 for omega), got shape {clean_data.shape}"
        )

    noisy_data = np.copy(clean_data)
    base_shape = clean_data.shape[:-1]

    R = clean_data[..., :9].reshape(base_shape + (3, 3))
    omega = clean_data[..., 9:12]

    # Rotation noise via exponential map
    eps = rng.normal(0.0, obs_noise_std, size=base_shape + (3,))
    theta = np.linalg.norm(eps, axis=-1)

    # hat(eps): skew-symmetric matrix of each rotation vector.
    ex = np.zeros(base_shape + (3, 3), dtype=np.float64)
    ex[..., 0, 1] = -eps[..., 2]
    ex[..., 0, 2] = eps[..., 1]
    ex[..., 1, 0] = eps[..., 2]
    ex[..., 1, 2] = -eps[..., 0]
    ex[..., 2, 0] = -eps[..., 1]
    ex[..., 2, 1] = eps[..., 0]

    # Rodrigues coefficients.  For a vanishing angle fall back to the
    # first-order expansion I + hat(eps) to avoid dividing by ~0.
    tiny = theta < 1e-10
    safe_theta = np.where(tiny, 1.0, theta)
    coef_lin = np.where(tiny, 1.0, np.sin(safe_theta) / safe_theta)
    coef_quad = np.where(tiny, 0.0, (1.0 - np.cos(safe_theta)) / safe_theta**2)

    R_perturb = (np.eye(3)
                 + coef_lin[..., np.newaxis, np.newaxis] * ex
                 + coef_quad[..., np.newaxis, np.newaxis] * (ex @ ex))

    # Right-multiplication: the perturbation is expressed in the body frame.
    noisy_data[..., :9] = (R @ R_perturb).reshape(base_shape + (9,))
    noisy_data[..., 9:12] = omega + rng.normal(0.0, obs_noise_std,
                                               size=base_shape + (3,))
    # Trailing action dims were carried over by the copy and stay untouched.
    return noisy_data
def sample_tennis_racket_3d(
    seed=0,
    timesteps=75,
    trials=50,
    disturbance_torque_std=0.0,
    omega_0_scale=2 * np.pi,
    perturb_std=0.05,
    axis_weights=(0.15, 0.70, 0.15),
    # [WIND]     wind_force_std=0.0,
    # [FRICTION] friction_coeff=0.0,
    # [FRICTION] varying_friction=False,
    racket_kwargs=None,
    **kwargs
):
    """
    Sample trajectories from the tennis_racket_3d environment.

    Each trial resets the env with its own seed and rolls it out with zero
    control for ``timesteps - 1`` steps.  A rollout is rejected and retried
    with a new seed (``+10``) when it contains NaNs or when the env reports
    ``terminated``/``truncated`` before the requested length is reached, so
    every returned trajectory has exactly ``timesteps`` rows.

    Parameters
    ----------
    seed : int
        Seed of the first trial; later trials use consecutive seeds.
    timesteps, trials : int
        Rollout length and number of trajectories (both >= 1).
    racket_kwargs : dict or None
        Geometry kwargs forwarded to the env constructor.
    **kwargs
        Further env constructor kwargs.

    Returns
    -------
    trajs : (timesteps, trials, obs_dim + act_dim) = (T, N, 15)
    tspan : (timesteps,)

    The action stored in the last 3 dims is the disturbance torque that was
    active during that trajectory (constant per trajectory, zero if
    disturbance_torque_std == 0).  This mirrors the pendulum convention of
    storing the applied torque alongside the state.

    Raises
    ------
    ValueError
        If ``timesteps`` or ``trials`` is smaller than 1.
    RuntimeError
        If a trial still has no valid trajectory after more than 50 retries.
    """
    if racket_kwargs is None:
        racket_kwargs = {}

    timesteps = int(timesteps)
    trials = int(trials)
    if timesteps < 1 or trials < 1:
        raise ValueError(
            f"timesteps and trials must be >= 1, got {timesteps} and {trials}"
        )

    env = tennis_racket_3d(
        disturbance_torque_std=disturbance_torque_std,
        omega_0_scale=omega_0_scale,
        perturb_std=perturb_std,
        axis_weights=axis_weights,
        # [WIND]     wind_force_std=wind_force_std,
        # [FRICTION] friction_coeff=friction_coeff,
        # [FRICTION] varying_friction=varying_friction,
        render_mode=None,
        **racket_kwargs,
        **kwargs
    )

    try:
        act_dim = env.action_space.shape[0]       # 3
        dt = env.dt

        trajs = []
        main_seed = int(seed)

        for trial in range(trials):
            retry_count = 0

            while True:
                if retry_count > 50:
                    raise RuntimeError(
                        "Too many retries generating a valid trajectory "
                        "(NaN, solver instability, or the episode ends "
                        f"before {timesteps} steps)."
                    )

                obs, info = env.reset(seed=main_seed)

                # The disturbance torque is fixed for this episode; record it as
                # the "action" so the dataset format matches the pendulum exactly.
                curr_u = env._disturbance_torque.copy().astype(np.float32)

                traj = [np.concatenate((obs, curr_u))]

                for _ in range(timesteps - 1):
                    # Pass zero control — the disturbance lives inside the env.
                    # To generate controlled trajectories, replace with your policy.
                    obs, reward, terminated, truncated, info = env.step(
                        np.zeros(act_dim, dtype=np.float32)
                    )
                    traj.append(np.concatenate((obs, curr_u)))

                    if terminated or truncated:
                        break

                traj = np.stack(traj, axis=0)   # (<= timesteps, 15)

                if traj.shape[0] < timesteps:
                    # A short rollout cannot be stacked with full-length
                    # ones and would disagree with tspan; resample it.
                    problem = f"Episode ended after {traj.shape[0]} steps"
                elif np.isnan(traj).any():
                    problem = "NaN detected"
                else:
                    break

                retry_count += 1
                main_seed += 10
                print(f"{problem} in trial {trial}, retrying (seed={main_seed})...")

            trajs.append(traj)
            main_seed += 1
    finally:
        # Release the env even when sampling fails.
        env.close()

    trajs = np.stack(trajs, axis=0)           # (trials, timesteps, 15)
    trajs = np.transpose(trajs, (1, 0, 2))    # (timesteps, trials, 15)
    tspan = np.arange(timesteps) * dt

    return trajs, tspan


# ─────────────────── Dataset ───────────────────

def _stale_setting_keys(cached_settings, requested_settings):
    """Return the names of settings whose cached value differs from the request."""
    stale = []
    for key, wanted in requested_settings.items():
        if key not in cached_settings:
            continue
        try:
            differs = bool(cached_settings[key] != wanted)
        except (ValueError, TypeError):
            # Values that cannot be compared are reported as differing.
            differs = True
        if differs:
            stale.append(key)
    return stale


def get_dataset(
    seed=0,
    samples=50,
    test_split=0.5,
    save_dir=None,
    racket_configs=(None,),        # list of dicts (or None for default)
    disturbance_torque_std=0.0,
    omega_0_scale=2 * np.pi,
    perturb_std=0.05,
    axis_weights=(0.15, 0.70, 0.15),
    obs_noise_std=0.0,
    timesteps=75,
    # [WIND]     wind_force_std=0.0,
    # [FRICTION] friction_coeff=0.0,
    # [FRICTION] varying_friction=False,
    **kwargs
):
    """
    Build (or load) a single pickle holding all racket geometry configurations.

    The pickle contains:
        data = {
            "x":            (num_configs, T, N_train, 15),   # train (possibly noisy)
            "test_x":       (num_configs, T, N_test,  15),   # test clean
            "test_x_noisy": (num_configs, T, N_test,  15),   # test noisy
            "t":            (T,),
            "inertia_info": list of dicts (one per config),
            "settings":     {...}
        }

    The leading num_configs axis mirrors the num_us axis in the pendulum
    datagen — each racket geometry is treated like a different "force" regime.

    Caching: the file name encodes only disturbance_torque_std,
    obs_noise_std, perturb_std, the number of configs and timesteps.  If a
    file with that name exists it is returned as-is; when its stored
    settings differ from the ones requested here a warning is printed (delete
    the file to force a rebuild).

    Split: the test share is capped at one half; any ``test_split >= 0.5``
    yields ``int(samples * 0.5)`` training trajectories.

    Returns
    -------
    data     : dict (as above)
    out_path : str

    Raises
    ------
    ValueError
        If ``save_dir`` is None or ``racket_configs`` is empty.
    """
    if save_dir is None:
        raise ValueError("save_dir must be specified.")
    os.makedirs(save_dir, exist_ok=True)

    # ── Build canonical geometry list ──
    resolved_configs = make_racket_configs(
        [c if c is not None else {} for c in racket_configs]
    )
    if not resolved_configs:
        raise ValueError("racket_configs must contain at least one entry.")

    settings = {
        'seed':                    seed,
        'samples':                 samples,
        'test_split':              test_split,
        'racket_configs':          resolved_configs,
        'disturbance_torque_std':  disturbance_torque_std,
        'omega_0_scale':           omega_0_scale,
        'perturb_std':             perturb_std,
        'axis_weights':            list(axis_weights),
        'obs_noise_std':           obs_noise_std,
        'timesteps':               timesteps,
        # [WIND]     'wind_force_std':  wind_force_std,
        # [FRICTION] 'friction_coeff':  friction_coeff,
        # [FRICTION] 'varying_friction': varying_friction,
    }

    # ── Filename mirrors pendulum convention ──
    dist_str    = f"dist{str(disturbance_torque_std).replace('.', 'p')}"
    obs_str     = f"obs_noise{str(obs_noise_std).replace('.', 'p')}"
    perturb_str = f"perturb{str(perturb_std).replace('.', 'p')}"
    ncfg_str    = f"ncfg{len(resolved_configs)}"
    steps_str   = f"steps{timesteps}"
    # [WIND]     wind_str  = f"wind{str(wind_force_std).replace('.', 'p')}"
    # [FRICTION] fric_str  = f"fric{str(friction_coeff).replace('.', 'p')}"
    filename = (
        f"tr3d_dataset_{dist_str}_{obs_str}_{perturb_str}"
        f"_{ncfg_str}_{steps_str}.pkl"
    )
    out_path = os.path.join(save_dir, filename)

    try:
        data = from_pickle(out_path)
    except FileNotFoundError:
        print(f"Building dataset at {out_path}...")
    else:
        # The file name does not encode every setting (seed, samples,
        # geometry, ...), so make a silent cache mismatch visible.
        cached_settings = data.get('settings', {}) if isinstance(data, dict) else {}
        stale = _stale_setting_keys(cached_settings, settings)
        if stale:
            print(
                f"WARNING: cached dataset {out_path} was built with different "
                f"settings for {stale}; delete the file to rebuild."
            )
        return data, out_path

    # ── Collect trajectories for each geometry config ──
    trajs_per_config = []
    inertia_info_list = []

    for i, rcfg in enumerate(resolved_configs):
        # Widely spaced seeds keep the per-config seed ranges disjoint.
        current_seed = int(seed) + (i * 10000)
        print(f"  Config {i+1}/{len(resolved_configs)}: {rcfg}")

        trajs, tspan = sample_tennis_racket_3d(
            seed=current_seed,
            timesteps=timesteps,
            trials=samples,
            disturbance_torque_std=disturbance_torque_std,
            omega_0_scale=omega_0_scale,
            perturb_std=perturb_std,
            axis_weights=axis_weights,
            # [WIND]     wind_force_std=wind_force_std,
            # [FRICTION] friction_coeff=friction_coeff,
            # [FRICTION] varying_friction=varying_friction,
            racket_kwargs=rcfg,
            **kwargs
        )
        trajs_per_config.append(trajs)

        # Record inertia info from a throw-away env instance
        _env = tennis_racket_3d(**rcfg, render_mode=None)
        try:
            inertia_info_list.append(_env.get_inertia_info())
        finally:
            _env.close()

    # (num_configs, T, N, D)
    all_clean_x = np.stack(trajs_per_config, axis=0)

    # ── Train / test split ──
    if test_split >= 0.5:
        if test_split > 0.5:
            print(f"Note: test_split={test_split} is capped at 0.5.")
        split_ix = int(samples * 0.5)
    else:
        split_ix = int(samples * (1.0 - test_split))

    train_clean_x = all_clean_x[:, :, :split_ix, :]
    test_clean_x  = all_clean_x[:, :, split_ix:, :]

    data = {}
    data['t'] = tspan
    data['inertia_info'] = inertia_info_list
    data['settings'] = settings

    # ── Observation noise ──
    if obs_noise_std > 0.0:
        print(f"Applying SO(3) observation noise (std={obs_noise_std})...")
        rng = np.random.default_rng(seed + 999)

        data['x']            = add_proper_noise_3d(train_clean_x, obs_noise_std, rng)
        data['test_x']       = test_clean_x
        data['test_x_noisy'] = add_proper_noise_3d(
            test_clean_x, obs_noise_std,
            np.random.default_rng(seed + 1999)
        )
    else:
        print("No observation noise added.")
        data['x']            = train_clean_x
        data['test_x']       = test_clean_x
        data['test_x_noisy'] = test_clean_x

    to_pickle(data, out_path)
    return data, out_path
argparse.ArgumentParser( + description="Generate a tennis-racket free-rigid-body dataset on SO(3)." + ) + parser.add_argument("--save_dir", type=str, default="data/tennis_data") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--samples", type=int, default=50) + parser.add_argument("--timesteps", type=int, default=100) + parser.add_argument("--test_split", type=float, default=0.5) + parser.add_argument("--obs_noise_std", type=float, default=0.0) + parser.add_argument("--perturb_std", type=float, default=0.05) + parser.add_argument("--disturbance_torque_std", type=float, default=0.0) + parser.add_argument("--omega_0_scale", type=float, default=6.2832) + # [WIND] parser.add_argument("--wind_force_std", type=float, default=0.0) + # [FRICTION] parser.add_argument("--friction_coeff", type=float, default=0.0) + # [FRICTION] parser.add_argument("--varying_friction", action="store_true") + + # Geometry overrides (applied to ALL configs when using CLI) + parser.add_argument("--head_a", type=float, default=0.195) + parser.add_argument("--head_b", type=float, default=0.135) + parser.add_argument("--handle_length", type=float, default=0.25) + parser.add_argument("--handle_radius", type=float, default=0.015) + parser.add_argument("--total_mass", type=float, default=0.290) + parser.add_argument("--head_mass_ratio", type=float, default=0.70) + + args = parser.parse_args() + + # ── Define geometry sweep ── + # Each dict overrides only the keys you care about; the rest use defaults. + # This example sweeps over three racket sizes (light/medium/heavy head). + # Edit freely — or pass a single {} for the default geometry only. 
+ racket_configs = [ + # Default geometry + {}, + # Heavier head (more pronounced intermediate-axis instability) + {"head_mass_ratio": 0.80, "total_mass": 0.300}, + # Lighter head / longer handle + {"head_mass_ratio": 0.60, "handle_length": 0.28, "total_mass": 0.280}, + # Wider head (larger a → I1 further from I2) + {"head_a": 0.210, "head_b": 0.135}, + ] + + # Override the first config with any CLI geometry flags + racket_configs[0] = dict( + head_a=args.head_a, + head_b=args.head_b, + handle_length=args.handle_length, + handle_radius=args.handle_radius, + total_mass=args.total_mass, + head_mass_ratio=args.head_mass_ratio, + ) + + data, path = get_dataset( + seed=args.seed, + samples=args.samples, + timesteps=args.timesteps, + test_split=args.test_split, + save_dir=args.save_dir, + racket_configs=racket_configs, + disturbance_torque_std=args.disturbance_torque_std, + omega_0_scale=args.omega_0_scale, + perturb_std=args.perturb_std, + obs_noise_std=args.obs_noise_std, + # [WIND] wind_force_std=args.wind_force_std, + # [FRICTION] friction_coeff=args.friction_coeff, + # [FRICTION] varying_friction=args.varying_friction, + ) + + print("\nDone.") + print(f"Train (x): {data['x'].shape}") + print(f"Test (clean): {data['test_x'].shape}") + print(f"Test (noisy): {data['test_x_noisy'].shape}") + print(f"Timesteps: {data['t'].shape}") + print("\nInertia tensors per config:") + for i, info in enumerate(data['inertia_info']): + print( + f" [{i}] I1={info['I1']:.5f} I2={info['I2']:.5f} " + f"I3={info['I3']:.5f} kg·m² " + f"(head_a={info['head_a']}, mass={info['total_mass']}kg)" + ) diff --git a/datasets/tennis_racket_3d_datagen_friction.py b/datasets/tennis_racket_3d_datagen_friction.py new file mode 100644 index 0000000..89f0305 --- /dev/null +++ b/datasets/tennis_racket_3d_datagen_friction.py @@ -0,0 +1,477 @@ +import numpy as np +import pickle +import os +import argparse + +import sys +from pathlib import Path +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, 
# ─────────────────── Helper Functions ───────────────────
# (identical to windy_pendulum_3d_datagen)

def to_pickle(obj, path):
    """Pickle ``obj`` to ``path`` and print a confirmation line."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    print(f"Saved data to {path}")


def from_pickle(path):
    """Load and return the pickled object stored at ``path``.

    Raises ``FileNotFoundError`` when ``path`` does not exist; ``get_dataset``
    relies on that to detect a missing cache file.

    NOTE: unpickling can execute arbitrary code. Only load files that were
    produced by this project.
    """
    with open(path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded data from {path}")
    return data


def arrange_data(x, t, num_points=2):
    """Arrange data to feed into a neural ODE in small sliding windows.

    x : (num_configs, T, N, D)
    t : (T,)

    Returns:
        x_stack : (num_configs, num_points, (T - num_points + 1) * N, D)
            x_stack[:, i] holds step ``i`` of every length-``num_points``
            window of every trajectory.
        t_eval : (num_points,) — the first ``num_points`` entries of ``t``.

    Raises:
        ValueError: if ``num_points`` is smaller than 2 or larger than the
            number of available time steps.
    """
    num_steps = x.shape[1]
    if not 2 <= num_points <= min(len(t), num_steps):
        raise ValueError(
            f"num_points must satisfy 2 <= num_points <= {min(len(t), num_steps)}, "
            f"got {num_points}"
        )
    n_windows = num_steps - num_points + 1
    # Slice i covers steps i .. i + n_windows - 1, i.e. the i-th element of
    # every window.
    x_stack = np.stack(
        [x[:, i:i + n_windows, :, :] for i in range(num_points)], axis=1
    )
    x_stack = np.reshape(x_stack, (x.shape[0], num_points, -1, x.shape[3]))
    t_eval = t[0:num_points]
    return x_stack, t_eval


def _project_to_so3(R):
    """Project matrices onto SO(3) with an SVD.

    R : (..., 3, 3) — a single matrix or any stack of matrices.

    Returns an array of the same shape whose matrices are orthogonal with
    determinant +1.
    """
    U, _, Vt = np.linalg.svd(R)
    # Where U @ Vt is a reflection (det = -1), flip the last left-singular
    # vector so the result becomes a proper rotation.
    sign = np.where(np.linalg.det(U @ Vt) < 0, -1.0, 1.0)
    U[..., :, -1] *= sign[..., None]
    return U @ Vt


def add_proper_noise_3d(clean_data, obs_noise_std, rng):
    """
    Apply geometrically-correct observation noise to SO(3) rigid-body data.

    Identical to windy_pendulum_3d_datagen.add_proper_noise_3d:
        R_noisy = R @ expm(hat(eps)),  eps ~ N(0, sigma²)
        ω_noisy = ω + noise,           noise ~ N(0, sigma²)
        action is NOT corrupted.

    clean_data : (..., D) where D = 9 (R, row-major) + 3 (ω) + 3 (u) = 15
    obs_noise_std : standard deviation used for both eps and the ω noise
    rng : numpy Generator; eps is drawn first, then the ω noise

    Returns a new array with the shape and dtype of ``clean_data``.

    Raises:
        ValueError: if the last axis is too short to hold R and ω.
    """
    if clean_data.shape[-1] < 12:
        raise ValueError(
            f"clean_data needs at least 12 features (9 for R, 3 for ω), "
            f"got {clean_data.shape[-1]}"
        )

    noisy_data = np.copy(clean_data)  # the action columns are kept as copied
    base_shape = clean_data.shape[:-1]

    R = clean_data[..., :9].reshape(base_shape + (3, 3))
    omega = clean_data[..., 9:12]

    # Rotation noise via the exponential map (Rodrigues formula), evaluated
    # for all samples at once.
    eps = rng.normal(0.0, obs_noise_std, size=base_shape + (3,))
    theta = np.linalg.norm(eps, axis=-1)

    hat = np.zeros(base_shape + (3, 3), dtype=np.float64)
    hat[..., 0, 1] = -eps[..., 2]
    hat[..., 0, 2] = eps[..., 1]
    hat[..., 1, 0] = eps[..., 2]
    hat[..., 1, 2] = -eps[..., 0]
    hat[..., 2, 0] = -eps[..., 1]
    hat[..., 2, 1] = eps[..., 0]

    # Below the threshold use the first-order expansion expm(hat) ≈ I + hat;
    # safe_theta avoids a division by zero in the branch that is discarded.
    small = theta < 1e-10
    safe_theta = np.where(small, 1.0, theta)
    coef_lin = np.where(small, 1.0, np.sin(safe_theta) / safe_theta)
    coef_quad = np.where(small, 0.0, (1.0 - np.cos(safe_theta)) / safe_theta**2)
    R_perturb = (np.eye(3)
                 + coef_lin[..., None, None] * hat
                 + coef_quad[..., None, None] * (hat @ hat))

    noisy_data[..., :9] = (R @ R_perturb).reshape(base_shape + (9,))
    noisy_data[..., 9:12] = omega + rng.normal(0.0, obs_noise_std, size=base_shape + (3,))
    return noisy_data


# ─────────────────── Racket geometry configs ───────────────────

def make_racket_configs(config_list):
    """
    Return a list of fully-populated racket geometry dicts.

    Each entry in config_list is a dict (or None, meaning "all defaults")
    whose keys are passed on as tennis_racket_3d constructor kwargs:
        head_a, head_b, handle_length, handle_radius,
        total_mass, head_mass_ratio

    Missing keys are filled with the defaults below. Entries are not
    validated here: unknown keys are passed through unchanged.
    """
    defaults = dict(
        head_a=0.195,
        head_b=0.135,
        handle_length=0.25,
        handle_radius=0.015,
        total_mass=0.290,
        head_mass_ratio=0.70,
    )
    return [{**defaults, **(cfg or {})} for cfg in config_list]


# ─────────────────── Sampling ───────────────────

def sample_tennis_racket_3d(
    seed=0,
    timesteps=75,
    trials=50,
    disturbance_torque_std=0.0,
    omega_0_scale=2 * np.pi,
    perturb_std=0.05,
    axis_weights=(0.15, 0.70, 0.15),
    # [WIND] wind_force_std=0.0,
    friction_coeff=0.0,
    varying_friction=False,
    racket_kwargs=None,
    **kwargs
):
    """
    Sample trajectories from the tennis_racket_3d environment.

    Returns
    -------
    trajs : (timesteps, trials, obs_dim + act_dim) = (T, N, 15)
    tspan : (timesteps,)

    The action stored in the last 3 dims is the disturbance torque that was
    active during that trajectory (constant per trajectory, zero if
    disturbance_torque_std == 0). This mirrors the pendulum convention of
    storing the applied torque alongside the state.

    A rollout is discarded and re-sampled with a new seed when it contains
    non-finite values or when the environment ends the episode before
    ``timesteps`` states were collected, so every returned trajectory has
    exactly ``timesteps`` rows.

    Raises
    ------
    ValueError : if ``timesteps`` or ``trials`` is smaller than 1.
    RuntimeError : if a single trial needs more than 50 retries.
    """
    if racket_kwargs is None:
        racket_kwargs = {}

    timesteps = int(timesteps)
    trials = int(trials)
    if timesteps < 1 or trials < 1:
        raise ValueError(
            f"timesteps and trials must be >= 1, got timesteps={timesteps}, "
            f"trials={trials}"
        )
    max_retries = 50

    env = tennis_racket_3d(
        disturbance_torque_std=disturbance_torque_std,
        omega_0_scale=omega_0_scale,
        perturb_std=perturb_std,
        axis_weights=axis_weights,
        # [WIND] wind_force_std=wind_force_std,
        friction_coeff=friction_coeff,
        varying_friction=varying_friction,
        render_mode=None,
        **racket_kwargs,
        **kwargs
    )

    try:
        act_dim = env.action_space.shape[0]  # 3
        dt = env.dt

        trajs = []
        main_seed = int(seed)

        for trial in range(trials):
            retry_count = 0

            while True:
                if retry_count > max_retries:
                    raise RuntimeError(
                        "Too many retries generating a valid trajectory "
                        "(NaN or solver instability)."
                    )

                obs, _ = env.reset(seed=main_seed)

                # The disturbance torque is fixed for this episode; record it as
                # the "action" so the dataset format matches the pendulum exactly.
                curr_u = env._disturbance_torque.copy().astype(np.float32)

                traj = [np.concatenate((obs, curr_u))]

                for _ in range(timesteps - 1):
                    # Pass zero control — the disturbance lives inside the env.
                    # To generate controlled trajectories, replace with your policy.
                    obs, _, terminated, truncated, _ = env.step(
                        np.zeros(act_dim, dtype=np.float32)
                    )
                    traj.append(np.concatenate((obs, curr_u)))

                    if terminated or truncated:
                        break

                traj = np.stack(traj, axis=0)  # (<= timesteps, 15)

                # A short rollout would make the final np.stack ragged and
                # disagree with tspan, so it is rejected like a NaN rollout.
                if len(traj) == timesteps and np.isfinite(traj).all():
                    break

                retry_count += 1
                main_seed += 10
                print(
                    f"Invalid trajectory in trial {trial} (non-finite values or "
                    f"early termination), retrying (seed={main_seed})..."
                )

            trajs.append(traj)
            main_seed += 1
    finally:
        # Release the environment even when sampling fails.
        env.close()

    trajs = np.stack(trajs, axis=0)  # (trials, timesteps, 15)
    trajs = np.transpose(trajs, (1, 0, 2))  # (timesteps, trials, 15)
    tspan = np.arange(timesteps) * dt

    return trajs, tspan


# ─────────────────── Dataset ───────────────────

def _values_differ(a, b):
    """Return True when two settings values are not equal (array-safe)."""
    try:
        return bool(a != b)
    except ValueError:
        # Element-wise array comparison has no single truth value.
        return not np.array_equal(a, b)


def _warn_if_cache_is_stale(data, requested, path):
    """Print a warning listing requested settings that differ from the cache.

    Only keys present in both the cached ``settings`` dict and ``requested``
    are compared. Returns the sorted list of differing keys.
    """
    cached = data.get('settings', {}) if isinstance(data, dict) else {}
    stale = sorted(
        key for key, value in requested.items()
        if key in cached and _values_differ(cached[key], value)
    )
    if stale:
        print(
            f"WARNING: {path} was generated with different settings for "
            f"{stale}; delete the file to regenerate it with the requested values."
        )
    return stale


def get_dataset(
    seed=0,
    samples=50,
    test_split=0.5,
    save_dir=None,
    racket_configs=(None,),  # list of dicts (or None for default)
    disturbance_torque_std=0.0,
    omega_0_scale=2 * np.pi,
    perturb_std=0.05,
    axis_weights=(0.15, 0.70, 0.15),
    obs_noise_std=0.0,
    timesteps=75,
    # [WIND] wind_force_std=0.0,
    friction_coeff=0.0,
    varying_friction=False,
    **kwargs
):
    """
    Build (or load) ONE pickle that holds all racket geometry configurations.

    The pickle contains:
        data = {
            "x":            (num_configs, T, N_train, 15),  # train (possibly noisy)
            "test_x":       (num_configs, T, N_test, 15),   # test clean
            "test_x_noisy": (num_configs, T, N_test, 15),   # test noisy
            "t":            (T,),
            "inertia_info": list of dicts (one per config),
            "settings":     {...}
        }

    The leading num_configs axis mirrors the num_us axis in the pendulum
    datagen — each racket geometry is treated like a different "force" regime.

    Caching: the file name encodes only the disturbance std, observation-noise
    std, perturbation std, friction coefficient, number of configs and number
    of time steps. If a file with that name exists it is returned as-is; a
    warning is printed when its stored settings differ from the request.

    Split: at least half of the samples always go to the training set;
    ``test_split`` values of 0.5 or more are treated as 0.5.

    Returns
    -------
    data : dict (as above)
    out_path : str

    Raises
    ------
    ValueError : if ``save_dir`` is None or ``racket_configs`` is empty.
    """
    if save_dir is None:
        raise ValueError("save_dir must be specified.")

    # ── Build canonical geometry list ──
    resolved_configs = make_racket_configs(
        [c if c is not None else {} for c in racket_configs]
    )
    if not resolved_configs:
        raise ValueError("racket_configs must contain at least one configuration.")

    os.makedirs(save_dir, exist_ok=True)

    # ── Filename mirrors pendulum convention ──
    dist_str = f"dist{str(disturbance_torque_std).replace('.', 'p')}"
    obs_str = f"obs_noise{str(obs_noise_std).replace('.', 'p')}"
    perturb_str = f"perturb{str(perturb_std).replace('.', 'p')}"
    ncfg_str = f"ncfg{len(resolved_configs)}"
    steps_str = f"steps{timesteps}"
    # [WIND] wind_str = f"wind{str(wind_force_std).replace('.', 'p')}"
    fric_str = f"fric{str(friction_coeff).replace('.', 'p')}"
    filename = (
        f"tr3d_dataset_{dist_str}_{obs_str}_{perturb_str}"
        f"_{fric_str}_{ncfg_str}_{steps_str}.pkl"
    )
    out_path = os.path.join(save_dir, filename)

    settings = {
        'seed': seed,
        'samples': samples,
        'test_split': test_split,
        'racket_configs': resolved_configs,
        'disturbance_torque_std': disturbance_torque_std,
        'omega_0_scale': omega_0_scale,
        'perturb_std': perturb_std,
        'axis_weights': list(axis_weights),
        'obs_noise_std': obs_noise_std,
        'timesteps': timesteps,
        # [WIND] 'wind_force_std': wind_force_std,
        'friction_coeff': friction_coeff,
        'varying_friction': varying_friction,
    }

    try:
        data = from_pickle(out_path)
    except FileNotFoundError:
        print(f"Building dataset at {out_path}...")
    else:
        # The file name does not encode every setting, so flag a mismatch
        # instead of silently handing back data from another run.
        _warn_if_cache_is_stale(data, settings, out_path)
        return data, out_path

    # ── Collect trajectories for each geometry config ──
    trajs_per_config = []
    inertia_info_list = []

    for i, rcfg in enumerate(resolved_configs):
        current_seed = int(seed) + (i * 10000)
        print(f" Config {i+1}/{len(resolved_configs)}: {rcfg}")

        trajs, tspan = sample_tennis_racket_3d(
            seed=current_seed,
            timesteps=timesteps,
            trials=samples,
            disturbance_torque_std=disturbance_torque_std,
            omega_0_scale=omega_0_scale,
            perturb_std=perturb_std,
            axis_weights=axis_weights,
            # [WIND] wind_force_std=wind_force_std,
            friction_coeff=friction_coeff,
            varying_friction=varying_friction,
            racket_kwargs=rcfg,
            **kwargs
        )
        trajs_per_config.append(trajs)

        # Record inertia info from a throw-away env instance
        _env = tennis_racket_3d(**rcfg, render_mode=None)
        try:
            inertia_info_list.append(_env.get_inertia_info())
        finally:
            _env.close()

    # (num_configs, T, N, D)
    all_clean_x = np.stack(trajs_per_config, axis=0)

    # ── Train / test split ──
    # The training fraction is never below one half.
    split_ix = int(samples * max(0.5, 1.0 - test_split))

    train_clean_x = all_clean_x[:, :, :split_ix, :]
    test_clean_x = all_clean_x[:, :, split_ix:, :]

    # ── Observation noise ──
    data = {}
    data['t'] = tspan
    data['inertia_info'] = inertia_info_list
    data['settings'] = settings

    if obs_noise_std > 0.0:
        print(f"Applying SO(3) observation noise (std={obs_noise_std})...")
        # Independent generators so the train and test noise do not depend on
        # each other's array sizes.
        train_rng = np.random.default_rng(seed + 999)
        test_rng = np.random.default_rng(seed + 1999)

        data['x'] = add_proper_noise_3d(train_clean_x, obs_noise_std, train_rng)
        data['test_x'] = test_clean_x
        data['test_x_noisy'] = add_proper_noise_3d(
            test_clean_x, obs_noise_std, test_rng
        )
    else:
        print("No observation noise added.")
        data['x'] = train_clean_x
        data['test_x'] = test_clean_x
        data['test_x_noisy'] = test_clean_x

    to_pickle(data, out_path)
    return data, out_path


# ─────────────────── CLI ───────────────────

def build_arg_parser():
    """Return the command-line parser for the friction dataset generator.

    Every flag has a default, so ``parse_args([])`` yields a complete
    configuration. The geometry defaults equal those of make_racket_configs.
    """
    parser = argparse.ArgumentParser(
        description="Generate a tennis-racket free-rigid-body dataset on SO(3)."
    )
    parser.add_argument("--save_dir", type=str, default="data/tennis_data")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--samples", type=int, default=50)
    parser.add_argument("--timesteps", type=int, default=100)
    parser.add_argument("--test_split", type=float, default=0.5)
    parser.add_argument("--obs_noise_std", type=float, default=0.0)
    parser.add_argument("--perturb_std", type=float, default=0.05)
    parser.add_argument("--disturbance_torque_std", type=float, default=0.0)
    parser.add_argument("--omega_0_scale", type=float, default=6.2832)
    # [WIND] parser.add_argument("--wind_force_std", type=float, default=0.0)
    parser.add_argument("--friction_coeff", type=float, default=0.0)
    parser.add_argument("--varying_friction", action="store_true")

    # Geometry flags. NOTE: these set the geometry of the FIRST config of the
    # sweep only; the other sweep entries are defined in main().
    parser.add_argument("--head_a", type=float, default=0.195)
    parser.add_argument("--head_b", type=float, default=0.135)
    parser.add_argument("--handle_length", type=float, default=0.25)
    parser.add_argument("--handle_radius", type=float, default=0.015)
    parser.add_argument("--total_mass", type=float, default=0.290)
    parser.add_argument("--head_mass_ratio", type=float, default=0.70)
    return parser


def main(argv=None):
    """Parse CLI flags, build (or load) the dataset and print a summary.

    argv : list of str or None
        Argument vector; ``None`` makes argparse read ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    # ── Define geometry sweep ──
    # Four racket geometries. Each dict overrides only the keys it names; the
    # remaining keys are filled in by make_racket_configs. Edit freely — or
    # keep only the first entry to generate the default geometry alone.
    racket_configs = [
        # Config 0: geometry taken from the CLI flags (defaults = stock racket)
        dict(
            head_a=args.head_a,
            head_b=args.head_b,
            handle_length=args.handle_length,
            handle_radius=args.handle_radius,
            total_mass=args.total_mass,
            head_mass_ratio=args.head_mass_ratio,
        ),
        # Heavier head (more pronounced intermediate-axis instability)
        {"head_mass_ratio": 0.80, "total_mass": 0.300},
        # Lighter head / longer handle
        {"head_mass_ratio": 0.60, "handle_length": 0.28, "total_mass": 0.280},
        # Wider head (larger a → I1 further from I2)
        {"head_a": 0.210, "head_b": 0.135},
    ]

    data, path = get_dataset(
        seed=args.seed,
        samples=args.samples,
        timesteps=args.timesteps,
        test_split=args.test_split,
        save_dir=args.save_dir,
        racket_configs=racket_configs,
        disturbance_torque_std=args.disturbance_torque_std,
        omega_0_scale=args.omega_0_scale,
        perturb_std=args.perturb_std,
        obs_noise_std=args.obs_noise_std,
        # [WIND] wind_force_std=args.wind_force_std,
        friction_coeff=args.friction_coeff,
        varying_friction=args.varying_friction,
    )

    print("\nDone.")
    print(f"Train (x): {data['x'].shape}")
    print(f"Test (clean): {data['test_x'].shape}")
    print(f"Test (noisy): {data['test_x_noisy'].shape}")
    print(f"Timesteps: {data['t'].shape}")
    print("\nInertia tensors per config:")
    for i, info in enumerate(data['inertia_info']):
        print(
            f" [{i}] I1={info['I1']:.5f} I2={info['I2']:.5f} "
            f"I3={info['I3']:.5f} kg·m² "
            f"(head_a={info['head_a']}, mass={info['total_mass']}kg)"
        )


if __name__ == "__main__":
    main()
Motivation + +The **Dzhanibekov effect** (intermediate-axis theorem) is a striking phenomenon in rigid-body mechanics: an object spinning freely about its second principal axis will spontaneously flip its orientation 180°, repeatedly, with no external torque applied. The effect is famously visible in microgravity and follows directly from the instability of the intermediate-inertia axis in Euler's equations. + +This project uses the tennis racket as the test system and pursues two goals: + +1. **Learn** the port-Hamiltonian structure $H(R,\omega)$ of the rigid body directly from trajectory data, with no prior knowledge of the inertia tensor. +2. **Control** the learned system — including stabilising the open-loop unstable intermediate axis and achieving full orientation + angular velocity targets on $SO(3)\times\mathbb{R}^3$. + +The progression is: dataset generation → Hamiltonian identification (Stages A–C) → IDA-PBC control design (Stages D–F). + +--- + +## 2. Mathematics + +### 2.1 Rigid body kinematics and Euler's equations + +Let $R \in SO(3)$ be the body orientation and $\omega \in \mathbb{R}^3$ the angular velocity in the body frame. The equations of motion are: + +$$ +I\dot{\omega} = -\omega \times (I\omega) + \tau, \qquad \dot{R} = R\,\widehat{\omega} +$$ + +where $I = \mathrm{diag}(I_1, I_2, I_3)$ is the principal-axis inertia tensor and $\widehat{\omega}$ is the skew-symmetric matrix satisfying $\widehat{\omega}\, v = \omega \times v$. For the tennis racket (cfg0): + +| Axis | Geometry | Inertia | Open-loop stability | +|------|----------|---------|---------------------| +| $e_1$ | long head axis | $I_1 = 5.72 \times 10^{-3}$ kg·m² | **stable** | +| $e_2$ | short head axis | $I_2 = 1.12 \times 10^{-2}$ kg·m² | **unstable** | +| $e_3$ | handle axis | $I_3 = 1.32 \times 10^{-2}$ kg·m² | **stable** | + +### 2.2 The intermediate-axis instability + +Linearise about steady spin $\omega^* = \omega_2 e_2$ (intermediate axis). 
The perturbation $(\delta\omega_1, \delta\omega_3)$ satisfies: + +$$ +\begin{pmatrix}\delta\dot{\omega}_1 \\ \delta\dot{\omega}_3\end{pmatrix} += \underbrace{\begin{pmatrix} 0 & \tfrac{(I_2-I_3)\omega_2}{I_1} \\ \tfrac{(I_1-I_2)\omega_2}{I_3} & 0\end{pmatrix}}_{A_{\text{free}}}\begin{pmatrix}\delta\omega_1\\ \delta\omega_3\end{pmatrix} +$$ + +The eigenvalues of $A_{\text{free}}$ are $\lambda = \pm\,\omega_2\sqrt{\dfrac{(I_2-I_3)(I_1-I_2)}{I_1 I_3}}$. + +Since $I_1 < I_2 < I_3$, both factors $(I_2 - I_3)$ and $(I_1 - I_2)$ are **negative**, so their product is positive and $\lambda$ is **real**: a saddle point with exponential divergence. At $\omega_2 = 2\pi$ rad/s, $\lambda = 2.40$ s$^{-1}$: the e-folding time is $1/\lambda = 0.42$ s and the doubling time is $\ln 2/\lambda = 0.29$ s. + +Spin about $e_1$ or $e_3$ gives $\lambda^2 < 0$ — imaginary eigenvalues, stable oscillation. + +### 2.3 Port-Hamiltonian structure + +The system is cast as a port-Hamiltonian system on $SO(3)\times\mathbb{R}^3$ with Hamiltonian: + +$$ +H(R,\omega) = \tfrac{1}{2}\,\omega^\top M^{-1}(R)\,\omega + V(R) +$$ + +where $M^{-1}(R)$ is the effective inverse inertia (equal to $\mathrm{diag}(1/I_1,1/I_2,1/I_3)$ for a free body), and $V(R) = 0$ for a body in free space. The full dynamics are: + +$$ +\dot{x} = \bigl[J(x) - D(x)\bigr]\nabla_x H + g(R)\,u +$$ + +with $x = (R,\omega)$ and: +- $J(x)$: skew-symmetric **interconnection matrix** (encodes gyroscopic coupling via the Lie–Poisson bracket on $\mathfrak{so}(3)^*$) +- $D(x) \succeq 0$: symmetric **dissipation matrix** +- $g(R)$: **input coupling matrix** ($= I_3$ for a direct body-frame torque actuator) + +The structure is **energy-consistent** by construction: $\dot{H} = -\omega^\top D\,\omega + \omega^\top u \leq \omega^\top u$ (power balance), guaranteeing passivity. + +--- + +## 3. What is learned + +Four neural sub-networks jointly parametrise the port-Hamiltonian structure. 
Each takes the extended configuration $q_\text{ext} = [\text{vec}(R),\, I_1, I_2, I_3] \in \mathbb{R}^{12}$ as input (in Stage C; $\text{vec}(R) \in \mathbb{R}^9$ alone in Stages A–B). + +| Network | Output | Constraint | Physical meaning | +|---------|--------|-----------|-----------------| +| `M_net` | $3\times 3$ | PSD via $LL^\top + \epsilon I$ | inverse inertia $M^{-1}(R)$ | +| `V_net` | scalar | — | potential energy $V(R)$ | +| `Dw_net` | $3\times 3$ | PSD | dissipation $D(R)$ | +| `g_net` | $3\times 3$ | — | input coupling $g(R)$ | + +The neural ODE is integrated with `torchdiffeq` RK4; the training loss is a **windowed geodesic loss** on $SO(3)$: + +$$ +\mathcal{L} = \frac{1}{N}\sum_{t}\bigl\|\log(R_{\text{pred}}(t)^\top R_{\text{true}}(t))\bigr\|_F^2 +$$ + +This is coordinate-free and respects the geometry of $SO(3)$. + +### 3.1 Key design challenge: Hamiltonian degeneracy + +For a free body ($V = 0$ in physics), the torque-free dynamics generated by $H = \frac{1}{2}\omega^\top M^{-1}\omega$ are invariant under $M \to \alpha M$ for any $\alpha > 0$ ($H$ itself is only rescaled by $1/\alpha$) — the dynamics only depend on **ratios** of inertia components, not their absolute scale. Without intervention, the optimizer finds a saddle $(M_\text{wrong},\, V_\text{compensating})$ where $V$ absorbs what $M$ misses: the geodesic loss converges while $M_\text{loss}$ diverges. + +**Fix 1** — `--lambda_V_zero 1.0`: adds a $\|V_\theta\|$ regularisation term to force $V \to 0$, recovering the physical $M$. + +**Fix 2** — `FixedInertiaFromState`: in Stage C the inertia values are embedded in the 18D state vector. Rather than training `M_net` to recover them from geometry (which reintroduces the degeneracy), we pin $M^{-1}$ to the exact physical value read from the state: + +```python +class FixedInertiaFromState(nn.Module): + def forward(self, q_ext): # q_ext: (N, 12) + return torch.diag_embed(1.0 / q_ext[:, 9:12]) # (N, 3, 3) +``` + +This makes multi-geometry generalisation exact by construction. + +--- + +## 4. 
Training stages + +### Stage A — single geometry, torque-free + +Train on one racket geometry ($I_1, I_2, I_3$ fixed) with zero applied torque. `--fix_M` pins `M_net` to the exact inertia and `--lambda_V_zero 1.0` forces $V \to 0$. + +**Result**: windowed geodesic loss $= 2.03\times 10^{-6}$ (threshold $0.01$ rad²). $V$ and $D_w$ converge to machine $\epsilon$. + +### Stage B — friction identification + +A physical friction torque $\tau_\text{fric} = -\kappa\,\omega$ is added to the dataset. `Dw_net` learns the dissipation without any supervision signal on $D_w$ itself. + +**Result**: $D_w$ loss $= 5.13 \times 10^{-14}$ — ten orders of magnitude below threshold. (Friction makes the system contractive, collapsing trajectory variance.) + +### Stage C — multi-geometry generalisation + +State extended from 15D to 18D: $[\text{vec}(R)_9,\, I_1 I_2 I_3,\, \omega_3,\, u_3]$. Train jointly on 4 geometries; test on a held-out 5th. + +**Result**: windowed geo $= 2.03\times 10^{-6}$ on all 4 geometries and on the unseen 5th. +Generalisation is **exact** because `FixedInertiaFromState` reads the exact $I$ from the embedded state — it does not need to learn anything geometry-dependent. + +--- + +## 5. Control design + +### 5.1 IDA-PBC framework + +**Interconnection and Damping Assignment PBC** (Ortega et al. 2001) designs a controller by specifying a desired closed-loop port-Hamiltonian system with Hamiltonian $H_d$, interconnection $J_d$, and dissipation $D_d$. The control law solves the **matching equation**: + +$$ +g(R)\,u = \bigl[J_d - D_d\bigr]\nabla H_d - \bigl[J - D\bigr]\nabla H +$$ + +For our environment the actuator applies torque directly, so $g = I_3$ exactly (confirmed by training: `g_net` $\to I_3$). The matching equation becomes straightforward. + +### 5.2 Stage D — velocity stabilisation (e₃ target) + +Choose $H_d = H + \frac{1}{2}K_p\|\omega - \omega^*\|^2$, $J_d = J$, $D_d = D + K_p I$. 
Substituting and cancelling: + +$$ +u = K_p(\omega^* - \omega), \qquad \omega^* = (0,0,2\pi) \text{ rad/s} +$$ + +The learned model is not used at control time — the IDA-PBC law simplifies entirely because $g = I_3$. The linearised closed-loop about $\omega^* = \omega_3^* e_3$ has eigenvalues: + +$$ +\lambda_3 = -K_p/I_3, \qquad \lambda_{1,2} = \tfrac{\text{tr}A \pm \sqrt{\text{tr}^2 A - 4\det A}}{2} +$$ + +Overdamped for $K_p \geq 0.071$. **Result**: $K_p = 0.10$ gives 100% convergence from Dzhanibekov tumbling in $0.30$ s. + +### 5.3 Stage E — saddle stabilisation and ZOH instability + +**Saddle stabilisation** (hold $e_2$): the closed-loop matrix at $\omega^* = \omega_2 e_2$ is: + +$$ +A_{cl} = \begin{pmatrix} -K_p/I_1 & (I_2-I_3)\omega_2/I_1 \\ (I_1-I_2)\omega_2/I_3 & -K_p/I_3 \end{pmatrix} +$$ + +Stability requires $\det(A_{cl}) > 0$, giving: + +$$ +K_{p,\min} = \omega_2\sqrt{|I_2-I_3|\cdot|I_1-I_2|} = 0.021 \text{ N·m·s/rad} +$$ + +Empirical threshold: $K_{p,\min} \approx 0.020$ (matches theory to within 5%). + +**ZOH discrete instability** (new finding): the environment applies $u$ at $dt = 0.05$ s and holds it constant for 10 integration substeps. The discrete-time eigenvalue for axis $i$ is: + +$$ +z_i = 1 - K_p\,T/I_i, \qquad T = dt = 0.05 \text{ s} +$$ + +For $|z_i| < 1$ we need $K_p < 2I_i/T$. The binding constraint is the smallest inertia $I_1$: + +$$ +K_{p,\max} = \frac{2I_1}{T} = \frac{2 \times 5.72\times10^{-3}}{0.05} = 0.229 \text{ N·m·s/rad} +$$ + +At $K_p = 0.50$: $z_1 = 1 - 4.37 = -3.37$, $|z_1| \gg 1$ — **confirmed divergence** to 14 rad/s. + +**Wind robustness**: constant disturbance $d$ creates a steady-state velocity offset $\|\delta\omega\|_{ss} = |d|/K_p$. The viable window under wind $\sigma = 0.05$ N·m: + +$$ +K_p \in \!\left(\frac{|d|}{\varepsilon_\text{conv}},\; \frac{2I_1}{T}\right) \approx (0.17,\; 0.23) +$$ + +$K_p = 0.20$ achieves 95% convergence; ratio $\|\delta\omega\|_\text{actual}/(|d|/K_p) = 1.019$ (linear prediction is tight). 
+ +### 5.4 Stage F — full-state geometric attitude control + +**What Stage D/E cannot do**: the proportional controller $u = K_p(\omega^*-\omega)$ stabilises $\omega$ but leaves $R$ free to drift anywhere on $SO(3)$. + +**Geodesic attitude error**: the natural error on $SO(3)$ is: + +$$ +e_R = \mathrm{vee}\!\bigl(\log(R^{*\top} R)\bigr) \in \mathbb{R}^3 +$$ + +where $\log: SO(3) \to \mathfrak{so}(3)$ is the matrix logarithm (Rodrigues formula): + +$$ +\log R = \frac{\theta}{2\sin\theta}(R - R^\top), \quad \theta = \arccos\!\tfrac{\text{tr}R - 1}{2} +$$ + +and $\mathrm{vee}$ extracts the axial vector of a skew-symmetric matrix. $\|e_R\| = \theta$ is the geodesic distance from $R^*$ to $R$ on $SO(3)$. + +**Stage F controller** (IDA-PBC with $H_d = H + \frac{1}{2}K_R\|e_R\|^2 + \frac{1}{2}K_p\|e_\omega\|^2$): + +$$ +u = -K_R\, e_R - K_p\,(\omega - \omega^*) +$$ + +The linearised closed-loop per axis $i$ is a **2nd-order oscillator**: + +$$ +\begin{pmatrix}\dot{e}_{R,i}\\ \dot{e}_{\omega,i}\end{pmatrix} = \begin{pmatrix}0 & 1 \\ -K_R/I_i & -K_p/I_i\end{pmatrix}\begin{pmatrix}e_{R,i}\\ e_{\omega,i}\end{pmatrix} +$$ + +with $\omega_n = \sqrt{K_R/I_i}$ and $\zeta = K_p/(2\sqrt{K_R I_i})$. + +**ZOH bound for the attitude loop**: the 2nd-order discrete system (Schur stability) gives an additional constraint: + +$$ +K_{R,\max} = K_p/T = 0.10/0.05 = 2.00 \text{ N·m/rad} +$$ + +**Orientation–velocity duality under wind**: with $\omega^* = 0$, constant wind $d$ is absorbed entirely into a steady-state **orientation** offset (the body comes to rest at a tilted angle): + +$$ +K_R\, e_R = d \;\Rightarrow\; \|e_R\|_{ss} = |d|/K_R, \qquad \|\omega\|_{ss} = 0 +$$ + +Compare with Stage E ($\omega^* \neq 0$): wind creates a **velocity** offset $|d|/K_p$. The ratio $\|e_R\|_\text{actual}/(|d|/K_R) = 1.000$ exactly for all tested $K_R \in [0.10, 1.50]$ N·m/rad. + +--- + +## 6. 
Results summary + +| Stage | Task | Metric | Result | +|-------|------|--------|--------| +| A | Single-config identification | windowed geo | $2.03\times10^{-6}$ rad² ✓ | +| B | Friction ($D_w$) identification | $D_w$ loss | $5.13\times10^{-14}$ ✓ | +| C | 4-config + held-out 5th | windowed geo | $2.03\times10^{-6}$ all ✓ | +| D | Stabilise $e_3$ spin from tumbling | conv time | 0.30 s, 100%, $K_p=0.10$ ✓ | +| E-B | Hold unstable $e_2$ axis | empirical $K_{p,\min}$ | 0.020 (theory 0.021) ✓ | +| E-C | Hold $e_2$ with wind $\sigma=0.05$ | 95% conv | $K_p=0.20$, ratio 1.019 ✓ | +| F-A | Rest at $R^*=I_3$, $\omega^*=0$ | conv time | 2.75 s, 100%, $K_R=0.10$ ✓ | +| F-B | Rest at $R^*=R_y(\pi/4)$, $\omega^*=0$ | conv time | 2.82 s, 100% ✓ | +| F-D | Rest at $I_3$ with wind $\sigma=0.05$ | 95% conv | $K_R=1.50$, ratio 1.000 ✓ | + +--- + +## 7. Key non-obvious findings + +:::info +**Hamiltonian degeneracy**: The dynamics generated by the free-body Hamiltonian $H=\frac{1}{2}\omega^\top M^{-1}\omega$ are unchanged under $M\to\alpha M$ ($H$ itself is only rescaled). Without regularisation, the optimizer finds a $(M_\text{wrong}, V_\text{compensating})$ saddle where the geodesic loss converges but $M$ is unphysical. Fixed by $\|V\|$ regularisation + `FixedInertiaFromState`. +::: + +:::warning +**ZOH discrete instability**: The gain bound from continuous-time analysis ($K_p$ any positive value) does not apply to the digital controller. Zero-order hold at $T=0.05$ s gives $z_1 = 1 - K_p T/I_1$; stability requires $K_p < 2I_1/T = 0.229$. Discovered empirically: $K_p=0.50$ diverges to 14 rad/s with $|z_1|=3.37$. +::: + +:::success +**Orientation–velocity duality**: The steady-state wind offset formula $\|\text{error}\|_{ss} = |d|/\text{gain}$ holds in **both** state spaces. In velocity control ($\omega^*\neq 0$): offset is in $\omega$, proportional to $1/K_p$. In attitude control ($\omega^*=0$): offset is in $R$, proportional to $1/K_R$. The measured-to-predicted ratio is 1.000 (to 3 decimal places) in the attitude case and 1.019 in the velocity case. 
+::: + +:::success +**Universal ZOH bound**: $K_{p,\max} = 2I_1/T$ is the same regardless of which axis is the control target, because $I_1$ (the smallest inertia) is always the first mode to go unstable. Confirmed: $K_p=0.50$ diverges on $e_1$, $e_2$, and $e_3$ hold with identical late-trajectory error $\approx 14$ rad/s. +::: + +--- + +## 8. Relationship to the broader LieSPHGP project + +This work parallels the 3D SO(3) Windy Pendulum model in the same repository, which is the primary case study. The tennis racket model: + +- **Shares** the `DissipativeSO3HamNODE` base network architecture and loss utilities +- **Extends** with a multi-geometry state (18D vs 15D) via the embedded inertia approach +- **Adds** `FixedInertiaFromState` as a novel module resolving the identifiability problem +- **Demonstrates** IDA-PBC control design through Stages D–F, which has not yet been applied to the pendulum + +The key structural lesson — that the free-body Hamiltonian degeneracy requires either regularisation ($V\to 0$) or structural pinning (`FixedInertiaFromState`) — is applicable to any rigid-body system without a gravity potential to break the symmetry. diff --git a/docs/hackmd_reference.md b/docs/hackmd_reference.md new file mode 100644 index 0000000..913f47b --- /dev/null +++ b/docs/hackmd_reference.md @@ -0,0 +1,535 @@ +# Dzhanibekov Effect — Full Technical Reference + +*Personal reference: every formula, file, decision, and fix. Last updated 2026-06-26.* + +[TOC] + +--- + +## 0. 
Repo layout + +``` +LieSPHGP/ +├── envs/ +│ ├── tennis_racket_3d.py # Gymnasium env (frictionless) +│ └── tennis_racket_3d_friction.py # Stage B env (with friction) +├── datasets/ +│ ├── tennis_racket_3d_datagen.py # Multi-config dataset generator +│ └── tennis_racket_3d_datagen_friction.py +├── src/ +│ ├── utils/ +│ │ └── loss_utils.py # geodesic loss, safe logm +│ └── models/ +│ ├── 3D_SO3_Windy_Pendulum/ph_nn_ode_v2/ +│ │ └── network.py # shared DissipativeSO3HamNODE base +│ └── 3D_SO3_Tennis_Racket/ph_nn_ode_v2/ +│ ├── network_stageC.py # 18D state, FixedInertiaFromState +│ ├── train.py # Stage A (single config, fix_M) +│ ├── train_e2e.py # Stage A end-to-end +│ ├── train_stageB.py # Stage B (friction) +│ ├── train_stageC.py # Stage C (multi-config) +│ ├── eval_stageC_5th.py # held-out geometry test +│ ├── controller_stageD.py # PDBodyFrameController +│ ├── controller_stageF.py # GeometricAttitudeController +│ ├── simulate_stageD.py # Stage D simulation + K_p scan +│ ├── simulate_stageE.py # Stage E (A, B, C scenarios) +│ ├── simulate_stageE_v2.py # Stage E v2 (+ e1+wind, e3+wind) +│ └── simulate_stageF.py # Stage F (full (R*, ω*) stabilisation) +└── docs/ + ├── hackmd_advisor.md # ← advisor-facing summary + └── hackmd_reference.md # ← this file +``` + +**Python env**: `/Users/katesur/Projects/LieSPHGP/venv/bin/python3` +Always run with `-u` for unbuffered output: `python3 -u script.py` + +--- + +## 1. Physics + +### 1.1 Euler's equations (body frame) + +$$ +I\dot{\omega} = -\omega \times (I\omega) + \tau +$$ +$$ +\dot{R} = R\,\widehat{\omega} +$$ + +$R\in SO(3)$, $\omega\in\mathbb{R}^3$ body-frame angular velocity, $I = \mathrm{diag}(I_1,I_2,I_3)$, $\widehat{\omega}$ skew-symmetric hat map. + +**Hat map**: $\widehat{v} = \begin{pmatrix}0&-v_3&v_2\\v_3&0&-v_1\\-v_2&v_1&0\end{pmatrix}$, so $\widehat{v}\,w = v\times w$. 
+ +**Vee map** (inverse): $\mathrm{vee}(\Omega)_i = \Omega_{kj}$ with $(i,j,k)$ cyclic, concretely $\mathrm{vee}(\Omega) = [\Omega_{32}, \Omega_{13}, \Omega_{21}]^\top$. + +### 1.2 Principal axis stability (linearisation) + +Steady spin about $e_k$ at rate $\omega_k^*$. Perturb in the plane $\{e_i, e_j\}$ with $(i,j,k)$ cyclic. + +From Euler: $I_i\delta\dot{\omega}_i = (I_j - I_k)\omega_k^*\,\delta\omega_j$, $I_j\delta\dot{\omega}_j = (I_k - I_i)\omega_k^*\,\delta\omega_i$. + +Eigenvalues: + +$$ +\lambda^2 = \frac{(I_j - I_k)(I_k - I_i)}{I_i\, I_j}\,(\omega_k^*)^2 +$$ + +- $\lambda^2 < 0$: $\lambda$ imaginary, **stable** oscillation ($e_1$ and $e_3$ for our racket) +- $\lambda^2 > 0$: $\lambda$ real, **saddle** unstable ($e_2$ — intermediate axis) + +For cfg0 at $\omega_2^* = 2\pi$ rad/s: $\lambda = \pm 2.40$ s$^{-1}$, doubling time $0.29$ s. + +### 1.3 Tennis racket geometry (cfg0) + +``` +head semi-axes: a=0.195 m (long, e₁), b=0.135 m (short, e₂) +handle: L=0.32 m, r=0.012 m +total mass: 0.312 kg +``` + +| Axis | Inertia (kg·m²) | $1/I_i$ (m$^{-2}$·kg$^{-1}$) | +|------|----------------|-----------------------------| +| $e_1$ (long head) | $5.719\times10^{-3}$ | 174.8 | +| $e_2$ (short head) | $1.121\times10^{-2}$ | 89.2 | +| $e_3$ (handle) | $1.322\times10^{-2}$ | 75.6 | + +Asymmetry index $(I_3-I_1)/I_2 = 0.67$ — large enough to produce visible flipping at $\omega_2^*=2\pi$ rad/s. NOTE (to reconcile): the tabulated inertias are exactly what `compute_racket_inertia` in `envs/tennis_racket_3d.py` returns for that file's default arguments (handle $L=0.25$ m, $r=0.015$ m, total mass $0.290$ kg, head mass ratio $0.70$); the handle length, radius and total mass listed in the block above do not reproduce them. + +--- + +## 2. Port-Hamiltonian structure + +### 2.1 Hamiltonian + +$$ +H(R,\omega) = \underbrace{\tfrac{1}{2}\,\omega^\top M^{-1}(R)\,\omega}_{\text{kinetic}} + \underbrace{V(R)}_{\text{potential}} +$$ + +For a free body: $M^{-1} = \mathrm{diag}(1/I_1,1/I_2,1/I_3)$ (constant), $V = 0$. 
+ +### 2.2 Equations of motion in pH form + +$$ +\dot{x} = [J(x) - D(x)]\nabla_x H + g(R)\,u +$$ + +The $\omega$-subsystem (expanding the Lie–Poisson bracket): + +$$ +I\dot{\omega} = -\widehat{\omega}(I\omega) - D_w\,\omega + g\,u +$$ + +which is exactly Euler's equations with damping $D_w$ and input coupling $g$. + +**Gradients**: $\nabla_\omega H = M^{-1}\omega$ (the angular momentum direction). $\nabla_R H$ involves $\partial V/\partial R$ and $\frac{1}{2}\omega^\top\frac{\partial M^{-1}}{\partial R}\omega$ (only matters when $M^{-1}$ or $V$ depend on $R$; zero for the free body). + +### 2.3 Why four sub-networks + +Every term in $[J - D]\nabla H + gu$ must be learned: + +| Term | Network | Notes | +|------|---------|-------| +| $M^{-1}(R)$ | `M_net` | PSD constraint via Cholesky $LL^\top + \epsilon I$ | +| $V(R)$ | `V_net` | potential energy; must be regularised to 0 for free body | +| $D_w(R)$ | `Dw_net` | PSD; learns friction if present; $\approx 0$ for frictionless | +| $g(R)$ | `g_net` | $3\times3$; converges to $I_3$ for direct-torque actuator | + +--- + +## 3. State representations + +| Stage | Dim | Layout | +|-------|-----|--------| +| A, B | 15 | $[\text{vec}(R)_9,\, \omega_3,\, u_3]$ | +| C | 18 | $[\text{vec}(R)_9,\, I_1 I_2 I_3,\, \omega_3,\, u_3]$ | + +**Inertia embedding**: writing $I_1,I_2,I_3$ directly into the state vector (cols 9:12 in Python slice notation, i.e. indices 9, 10, 11) lets one model serve all racket geometries. At training and inference time, `FixedInertiaFromState` reads those values and returns the exact physical $M^{-1}$. + +**`strip_inertia`**: the loss function expects 15D input. Before computing loss, remove cols 9:12 from the 18D vector. The `split=[9,3,3]` (vec(R), ω, u) is unchanged after stripping. + +--- + +## 4. File-by-file descriptions + +### `envs/tennis_racket_3d.py` +Gymnasium env. Lie-group Heun integrator, 10 substeps per `dt=0.05 s`. Observation: `obs = [vec(R)_9, ω_3]` ∈ ℝ¹². Action: body-frame torque `u` ∈ ℝ³ (declared action-space bound ±2 N·m; `step()` applies `u` as given and does not clip it). 
+ +Key: `tau = u + disturbance_torque` — so $g = I_3$ exactly and any `disturbance_torque_std` adds mean-zero noise to the applied torque. + +`reset(options={"axis": k})` → random $R_0$, spin $\omega_0 = \omega^*_k e_k + \epsilon$. +`reset(options={"R_init": R})` → use specified $R_0$. +`get_inertia_info()` → dict of $I_1, I_2, I_3$, head dims, mass. + +### `network_stageC.py` +Defines `DissipativeSO3HamNODE` for 18D state. Contains `FixedInertiaFromState`: + +```python +class FixedInertiaFromState(nn.Module): + """M_net replacement that reads I from state cols 9:12 exactly.""" + def forward(self, q_ext): # q_ext: (N, 12) [vec(R), I1, I2, I3] + return torch.diag_embed(1.0 / q_ext[:, 9:12]) # (N, 3, 3) +``` + +Activated by passing `fix_M=True` to the model constructor. + +### `train_stageC.py` +Training loop for Stage C. Key flags: + +| Flag | Effect | +|------|--------| +| `--fix_M` | Use `FixedInertiaFromState` instead of learned `M_net` | +| `--lambda_V_zero 1.0` | Add $\|V_\theta\|^2$ regularisation (needed for e2e) | +| `--total_steps 5000` | Number of gradient steps | +| `--num_points 5` | Window size for windowed geodesic loss | +| `--n_samples 25` | Trajectories per batch | + +Checkpoint path: `data/run_tr3d_stageC_fp32/.../tr3d-so3ham-rk4-5p.tar` + +### `controller_stageD.py` +`PDBodyFrameController`: proportional ω-only controller. + +```python +def __call__(self, R_flat, omega): + tau = self.Kp * (self.omega_star - np.asarray(omega)) + return np.clip(tau, -self.clip, self.clip) +``` + +No model used. `use_model_g=False` (default); setting `True` would apply `g_net⁻¹` via `lstsq`, but since $g \approx I_3$ this makes no difference. + +### `controller_stageF.py` +`GeometricAttitudeController`: full $(R^*, \omega^*)$ stabiliser. 
+ +Key helpers: +```python +def logm_SO3(R): + """Rodrigues formula; handles θ≈0 and θ≈π edge cases.""" + cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0) + theta = float(np.arccos(cos_theta)) + if theta < 1e-7: + return (R - R.T) / 2.0 # first-order: near identity + if abs(theta - np.pi) < 1e-4: # antipodal case + sym = (R + np.eye(3)) / 2.0 # = n nᵀ + i = int(np.argmax(np.diag(sym))) + n = sym[:, i] / np.sqrt(sym[i, i]) + return theta * hat(n) + return theta / (2.0 * np.sin(theta)) * (R - R.T) + +def vee(Omega): # Omega skew-symmetric + return np.array([Omega[2,1], Omega[0,2], Omega[1,0]]) +``` + +Controller call: +```python +def __call__(self, R_flat, omega): + e_R = vee(logm_SO3(self.R_star.T @ R_flat.reshape(3,3))) + e_w = np.asarray(omega) - self.omega_star + u = -self.K_R * e_R - self.K_p * e_w + return np.clip(u, -self.clip, self.clip) +``` + +When `K_R = 0`: reduces exactly to `PDBodyFrameController`. Confirmed (Stage F Scenario C = Stage D, 0.30 s, 100%). + +### `simulate_stageD.py` +K_p scan, linearisation printout. Bug fix history: +- **Bug 1**: overdamped threshold formula `sqrt(a12*a21)` returned NaN (product < 0). Fix: `sqrt(abs(a12*a21))`. +- **Bug 2**: `best_Kp = max(all_results, ...)` picked first (slowest) not fastest. Fix: filter ≥95% conv, then `min(..., key=conv_times.mean)`. + +### `simulate_stageE.py` / `simulate_stageE_v2.py` +E.py: Scenarios A (e₁ redirect), B (e₂ hold no wind), C (e₂ hold with wind). +E_v2.py: adds D (e₁ hold with wind), E (e₃ hold with wind). + +Key functions: +```python +def kp_max_zoh(I1, dt): + return 2.0 * I1 / dt # = 0.229 for cfg0 + +def kp_min_e2(I1, I2, I3, omega_star): + return omega_star * np.sqrt(abs(I2 - I3) * abs(I1 - I2)) # = 0.021 +``` + +Bug fix: `scenario_C` verdict checked `Kp=0.10` hardcoded (was 80% conv). Fix: find best-converging Kp with `late_max < 0.5` across the scan. + +### `simulate_stageF.py` +Scenarios: A (R*=I, ω*=0), B (R*=Ry(π/4), ω*=0), C (K_R=0 sanity), D (with wind, K_R scan). 
+ +```python +def print_linearised_stability(I1, I2, I3, K_R, K_p, dt): + # Per axis: ωn = sqrt(K_R/I), ζ = K_p / (2*sqrt(K_R*I)) + # ZOH K_p_max = 2*I1/dt, ZOH K_R_max = K_p/dt +``` + +Bug fix: first run labeled wind scenario "UNSTABLE" because `late_eR > theta_thr*5`. Changed to detect truly growing errors (2nd-half mean vs initial value) vs bounded SS offset. + +--- + +## 5. Training stage details + +### Stage A — `train.py` + +```bash +python3 -u train.py --fix_M --lambda_V_zero 0 --total_steps 5000 +``` + +Dataset: `tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl` +Shape: `(4, 100, 25, 15)` — 4 configs, 100 timesteps, 25 trajectories, 15D state. + +**Loss terms**: +- `geo_loss`: windowed geodesic loss on $SO(3)$ +- `V_loss`: MSE of `V_net` output vs 0 (when `--lambda_V_zero` set) +- `M_loss`: MSE of `M_net` output vs exact $M_\text{tgt}$ (only with `--fix_M`) + +Result at 5000 steps: `windowed_geo = 2.03e-6`, `V_loss ≈ 0`, `Dw_loss ≈ 0`. + +### Stage B — `train_stageB.py` + +Dataset with `fric_coeff=0.01`: `tau_fric = -0.01 * omega`. +`Dw_net` learns dissipation unsupervised. +Key: friction makes trajectories contractive → near-zero variance → `Dw_loss = 5.13e-14`. + +### Stage C — `train_stageC.py` + +Four racket geometries trained jointly. Checkpoint naming: +``` +run_tr3d_stageC_fp32/obs0_dist0_cfgALL_lP0_lV0_lB0_lD0_lr0p001_s5000_np5_smp25_T100_rk4_seed0_fixM_YYMMDD-HHMMSS/tr3d-so3ham-rk4-5p.tar +``` + +`strip_inertia(x)`: removes cols 9:12 from shape `(*, 18)` to get `(*, 15)` for loss. + +**Identifiability note**: without `--fix_M`, `M_net` output initialises near $\epsilon I \approx 1 \cdot I_3$ (Cholesky output near zero → $L L^\top \approx 0 + \epsilon I$). True targets are $O(100)$ (e.g. $1/I_1 = 174.8$). This is why a high learning rate `lr=0.1` is needed for `M_net` pretraining, not the default `lr=1e-3`. + +### Stage C evaluation — `eval_stageC_5th.py` + +Evaluates on a 5th geometry not seen during training. 
Result: `windowed_geo = 2.03e-6`. This is exact because `FixedInertiaFromState` reads the embedded $I$ directly — generalisation is by construction, not by the network having learned geometry. + +--- + +## 6. Control mathematics + +### 6.1 IDA-PBC matching equation + +For port-Hamiltonian $\dot{x} = [J-D]\nabla H + gu$, find $u$ such that the closed-loop equals $[J_d - D_d]\nabla H_d$: + +$$ +g\,u = [J_d - D_d]\nabla H_d - [J - D]\nabla H +$$ + +For $g = I_3$: + +$$ +u = [J_d - D_d]\nabla H_d - [J - D]\nabla H +$$ + +**Stage D choice**: $H_d = H + \frac{1}{2}K_p\|\omega-\omega^*\|^2$, $J_d = J$, $D_d = D + K_p I$. + +$\nabla_\omega H_d = M^{-1}\omega + K_p(\omega-\omega^*)$. Substituting, the $[J_d-D_d]\nabla H_d$ and $[J-D]\nabla H$ terms cancel except for: + +$$ +u = K_p(\omega^* - \omega) +$$ + +**Stage F choice**: additionally add $\frac{1}{2}K_R\|e_R\|^2$ to $H_d$. Since $\nabla_\omega(\|e_R\|^2) \approx 0$ (orientation error doesn't depend on $\omega$ in the body frame), the additional term enters only through the $\nabla_R H_d$ equation. The result: + +$$ +u = -K_R e_R - K_p(\omega-\omega^*) +$$ + +### 6.2 ZOH discrete-time analysis + +**First-order (ω loop only, Stage D/E)**: + +The discrete map per axis $i$ at $T = dt$: + +$$ +\omega_i[k+1] = \omega_i[k] - (K_p T / I_i)\,(\omega_i[k] - \omega_i^*) = (1 - K_p T/I_i)\omega_i[k] + \ldots +$$ + +Stability: $|z_i| = |1 - K_p T/I_i| < 1$, i.e. $0 < K_p T/I_i < 2$: + +$$ +\boxed{K_{p,\max} = \frac{2I_1}{T} = \frac{2\times5.72\times10^{-3}}{0.05} = 0.229 \text{ N·m·s/rad}} +$$ + +**Second-order (R + ω loop, Stage F)**: + +Characteristic polynomial of the ZOH-sampled 2nd-order system: + +$$ +z^2 - \underbrace{\left(2 - \frac{K_p T}{I_i}\right)}_{\alpha} z + \underbrace{\left(1 - \frac{K_p T}{I_i} + \frac{K_R T^2}{I_i}\right)}_{\beta} = 0 +$$ + +Schur stability conditions: $|\beta| < 1$ and $|\alpha| < 1 + \beta$. 
The binding condition from $\beta < 1$: + +$$ +\boxed{K_{R,\max} = K_p / T = 0.10/0.05 = 2.00 \text{ N·m/rad}} +$$ + +### 6.3 Saddle stabilisation (Stage E) + +Closed-loop linearisation about $\omega^* = \omega_2^* e_2$ with $u = K_p(\omega^*-\omega)$: + +$$ +A_{cl} = \begin{pmatrix} -K_p/I_1 & (I_2-I_3)\omega_2^*/I_1 \\ (I_1-I_2)\omega_2^*/I_3 & -K_p/I_3 \end{pmatrix} +$$ + +$\text{tr}(A) = -K_p(1/I_1 + 1/I_3) < 0$ always. + +$\det(A) = K_p^2/(I_1 I_3) - (I_2-I_3)(I_1-I_2)(\omega_2^*)^2/(I_1 I_3)$ + +Note: $(I_2-I_3) < 0$ and $(I_1-I_2) < 0$ so the product is **positive**. For $\det > 0$: + +$$ +\boxed{K_{p,\min} = \omega_2^*\sqrt{|I_2-I_3|\cdot|I_1-I_2|} = 0.021 \text{ N·m·s/rad}} +$$ + +Empirical value: $K_p = 0.020$ gives ≥95% conv; $K_p = 0.015$ gives 90%. + +### 6.4 Steady-state offset under constant wind + +**Stage E** ($\omega^* \neq 0$, proportional ω controller): at SS, $\dot{\omega} = 0$: + +$$ +0 = -K_p(\omega_{ss} - \omega^*) + d \;\Rightarrow\; \|\delta\omega\|_{ss} = |d|/K_p +$$ + +The orientation $R$ is unconstrained and drifts freely. + +**Stage F** ($\omega^* = 0$, attitude+velocity controller): at SS, $\dot{\omega} = 0$ and $\dot{R} = R\widehat{\omega} = 0$: + +$$ +0 = -K_R e_R - K_p\,0 + d \;\Rightarrow\; \|e_R\|_{ss} = |d|/K_R, \quad \|\omega\|_{ss} = 0 +$$ + +The angular velocity is zero; the body tilts until the orientation restoring torque balances the wind. + +Measured ratios $\|e_R\|_\text{actual}/(|d|/K_R)$ at Stage F Scenario D: + +| $K_R$ | Ratio | +|-------|-------| +| 0.10 | 1.000 | +| 0.50 | 1.000 | +| 1.00 | 1.000 | +| 1.50 | 1.000 | + +Predicted window for $\|e_R\|_{ss} < \theta_\text{thr} = 0.10$ rad with $|d| \approx 0.082$ N·m: + +$$ +K_R \in \bigl(0.082/0.10,\; 0.10/0.05\bigr) = (0.82,\; 2.00) +$$ + +Best operating point: $K_R = 1.50$ (middle of window, 95% conv in 1.80 s). 
+ +### 6.5 SO(3) geometry for Stage F + +**Matrix logarithm** $\log: SO(3) \to \mathfrak{so}(3)$ via Rodrigues: + +$$ +\log R = \frac{\theta}{2\sin\theta}(R - R^\top), \quad \theta = \arccos\!\left(\frac{\text{tr}R - 1}{2}\right) +$$ + +Edge cases: +- $\theta \approx 0$: $\log R \approx (R-R^\top)/2$ (first-order expansion) +- $\theta \approx \pi$ (antipodal): $R = 2\mathbf{n}\mathbf{n}^\top - I$ for rotation axis $\mathbf{n}$; reconstruct $\mathbf{n}$ from symmetric part $(R+I)/2 = \mathbf{n}\mathbf{n}^\top$ + +**Attitude error**: $e_R = \mathrm{vee}(\log(R^{*\top}R)) \in \mathbb{R}^3$; $\|e_R\| = \theta_\text{err} \in [0,\pi]$. + +**Almost-global stability**: $\|e_R\| = \pi$ is the only failure point (antipodal configuration). On SO(3) this is a closed set of measure zero; starting from any smooth distribution of initial conditions the probability of hitting it exactly is zero. + +--- + +## 7. Complete results tables + +### Stage D — K_p scan (20 trials, axis=1 start, no wind) + +| $K_p$ | Conv time (s) | Final $\|\omega-\omega^*\|$ | % conv | Mode | +|-------|--------------|--------------------------|--------|------| +| 0.01 | 2.97 ± 0.xx | 0.002 | 100 | oscillatory | +| 0.05 | 0.65 | ~0 | 100 | oscillatory | +| **0.10** | **0.30** | **~0** | **100** | overdamped ✓ | +| 0.20 | 0.10 | ~0 | 100 | overdamped | + +Overdamped threshold: $K_p \geq 0.071$ (from $\text{disc}(A_{cl}) = 0$). + +### Stage E — scenarios + +| Scenario | Target | Wind | Kp range | Best Kp | % conv | SS ratio | +|----------|--------|------|----------|---------|--------|---------| +| A | $e_1$ redirect | — | 0.10 | 0.10 | 100 | — | +| B | $e_2$ hold | — | 0.010–0.500 | 0.025 | 100 | — | +| C | $e_2$ hold | 0.05 | 0.05–0.50 | 0.20 | 95 | 1.019 | +| D | $e_1$ hold | 0.05 | 0.05–0.50 | 0.20 | 95 | 0.987 | +| E | $e_3$ hold | 0.05 | 0.05–0.50 | 0.20 | 95 | 0.999 | + +ZOH instability at $K_p=0.50$: late_max ≈ 14 rad/s on all three axes. 
+ +### Stage F — scenarios + +| Scenario | $(R^*, \omega^*)$ | $(K_R, K_p)$ | % conv | Conv time | Notes | +|----------|------------------|-------------|--------|-----------|-------| +| A | $(I_3, \mathbf{0})$ | (0.10, 0.10) | 100 | 2.75 s | | +| B | $(R_y(\pi/4), \mathbf{0})$ | (0.10, 0.10) | 100 | 2.82 s | arbitrary target orientation | +| C (sanity) | $(I_3, (0,0,2\pi))$ | (0, 0.10) | 100 | 0.30 s | = Stage D exactly | +| D wind σ=0.05 | $(I_3, \mathbf{0})$ | (1.50, 0.10) | 95 | 1.80 s | ratio=1.000 | + +--- + +## 8. Run commands (all stages) + +```bash +VENV=/Users/katesur/Projects/LieSPHGP/venv/bin/python3 +MODEL=src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2 + +# ── Data ────────────────────────────────────────────────────────────────────── +$VENV datasets/tennis_racket_3d_datagen.py \ + --ncfg 4 --timesteps 100 --trials 25 \ + --perturb_std 0.05 --dist_std 0.0 --obs_noise 0.0 + +# ── Stage A ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/train.py --fix_M --total_steps 5000 + +# ── Stage A e2e ─────────────────────────────────────────────────────────────── +$VENV -u $MODEL/train_e2e.py --lambda_V_zero 1.0 --total_steps 5000 + +# ── Stage B ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/train_stageB.py --fix_M --total_steps 5000 + +# ── Stage C ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/train_stageC.py --fix_M --total_steps 5000 \ + --num_points 5 --n_samples 25 + +# ── Stage C eval (held-out 5th geometry) ───────────────────────────────────── +$VENV -u $MODEL/eval_stageC_5th.py + +# ── Stage D ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/simulate_stageD.py --Kp_scan "0.01,0.05,0.10,0.20" + +# ── Stage E ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/simulate_stageE.py # A, B, C +$VENV -u $MODEL/simulate_stageE_v2.py # A–E incl. 
wind + +# ── Stage F ─────────────────────────────────────────────────────────────────── +$VENV -u $MODEL/simulate_stageF.py +$VENV -u $MODEL/simulate_stageF.py --n_steps 400 --theta_thr 0.05 # stricter +``` + +--- + +## 9. Known issues / gotchas + +**`M_net` init too small**: Cholesky-based PSD output initialises near $\epsilon I$ (values $\approx 10^{-3}$). True $1/I_i \approx O(100)$. Need `lr=0.1` for `M_net` in standalone pretraining; `--fix_M` sidesteps this entirely. + +**`split=[9,3,3]` unchanged in Stage C**: The loss `split` always describes the 15D stripped state (vec(R), ω, u). Never include the embedded inertia in `split`. + +**`env._disturbance_torque`**: This attribute stores the constant disturbance sampled at each `env.reset()`. Use it to compute the per-trial disturbance norm for the SS offset verification. + +**Stage F: `K_R=0.01` gives `BOUNDED_SS` not `CONVERGED`**: With initial $\|e_R\|\approx 2.2$ rad and $\omega_n = 1.32$ rad/s, the convergence time to $\|e_R\|<0.10$ rad is $\approx 2.2/1.32 \times 2\zeta = 2.2/1.32 \times 13.2 \approx 22$ s. Our 15 s window is too short. Run with `--n_steps 600` to confirm convergence. + +**`simulate_stageF.py`: `scenario_D_wind` creates a fresh `Env` per $K_R$ value** with the same seed, ensuring all $K_R$ values see identical disturbance vectors — necessary for the clean ratio comparison. + +--- + +## 10. Next possible directions + +- **Stage F with `g_net⁻¹`**: For non-trivial input coupling (e.g. reaction wheels), $u = g_\theta(R)^{-1}(-K_R e_R - K_p e_\omega)$ would use the learned model at control time. Currently redundant because $g \approx I_3$. +- **Model-predictive control (MPPI/iLQR)**: Use Stage C neural ODE as a differentiable rollout model for trajectory optimisation. +- **Apply to the Windy Pendulum**: Port Stages D–F controllers to the 3D SO(3) pendulum, which has a non-trivial gravity potential $V(R) \neq 0$ — the learned $V_\text{net}$ would directly enter the IDA-PBC matching equation. 
+- **Continuous-time robust control**: The ZOH instability suggests that designing in continuous time then discretising is unsafe for this class of systems. A proper discrete-time IDA-PBC design would be more principled. diff --git a/envs/tennis_racket_3d.py b/envs/tennis_racket_3d.py new file mode 100644 index 0000000..34d5a01 --- /dev/null +++ b/envs/tennis_racket_3d.py @@ -0,0 +1,691 @@ +from __future__ import annotations + +"""Free rigid body (tennis racket) environment on SO(3). + +Models a tennis racket as an elliptical hoop (head) + cylindrical rod (handle). +The principal moments of inertia are computed analytically from geometry. +The equations of motion are Euler's equations for a torque-free (or torqued) +rigid body, integrated with the same Lie-group Heun scheme as windy_pendulum_3d. + +State : (R ∈ SO(3), ω ∈ ℝ³) → obs vector length 12 (same as pendulum) +Action : τ ∈ ℝ³ (body-frame torque) length 3 + +Body-frame principal axes convention +------------------------------------- + e₁ — long axis of head (largest semi-axis a, SMALLEST moment I₁) + e₂ — short axis of head (semi-axis b, INTERMEDIATE I₂) + e₃ — handle axis (out-of-plane, LARGEST moment I₃) + +This ordering guarantees I₁ < I₂ < I₃ for realistic racket geometries, +so rotation about e₂ is the unstable intermediate axis (Dzhanibekov effect). + +Inertia formulae (thin-shell / thin-rod approximations) +--------------------------------------------------------- +Elliptical hoop (mass m_h, semi-axes a, b): + I_hoop_1 = (1/2) m_h b² (about e₁) + I_hoop_2 = (1/2) m_h a² (about e₂) + I_hoop_3 = (1/2) m_h (a² + b²) (about e₃) + +Thin rod (mass m_r, length L, radius r_r ≪ L): + COM of rod is offset d = a + L/2 from racket COM (along -e₁ / handle dir) + I_rod_1 = (1/12) m_r L² + m_r d² (parallel axis, about e₁) + I_rod_2 = (1/12) m_r L² + m_r d² (same by symmetry) + I_rod_3 = (1/2) m_r r_r² (about handle axis ≈ 0 for thin rod) + +Total: I_i = I_hoop_i + I_rod_i for i = 1, 2, 3. 
# ──────────────────────────────────────────────────────────────────────────────
# SO(3) Lie group utilities  (adapted from windy_pendulum_3d)
# ──────────────────────────────────────────────────────────────────────────────

def _hat(w: np.ndarray) -> np.ndarray:
    """Return the skew-symmetric (hat) matrix of ``w`` in R^3.

    The result satisfies ``_hat(w) @ v == np.cross(w, v)``.
    """
    wx, wy, wz = w
    return np.array([[0.0, -wz,  wy],
                     [wz,  0.0, -wx],
                     [-wy, wx,  0.0]], dtype=np.float64)


def _vee(W: np.ndarray) -> np.ndarray:
    """Inverse hat map: read ``[W[2,1], W[0,2], W[1,0]]`` out of a 3x3 matrix.

    For a skew-symmetric ``W = _hat(w)`` this returns ``w``.
    """
    return np.array([W[2, 1], W[0, 2], W[1, 0]], dtype=np.float64)


def _exp_so3(phi: np.ndarray) -> np.ndarray:
    """Matrix exponential on so(3) via Rodrigues' formula.

    Parameters
    ----------
    phi : rotation vector (axis * angle), shape (3,)

    Returns
    -------
    (3, 3) rotation matrix ``I + A*hat(phi) + B*hat(phi)^2``.
    """
    theta_sq = np.dot(phi, phi)
    theta = np.sqrt(theta_sq)
    Phi = _hat(phi)
    if theta < 1e-10:
        # Taylor expansions avoid 0/0 at (numerically) zero rotation angle.
        A = 1.0 - theta_sq / 6.0
        B = 0.5 - theta_sq / 24.0
    else:
        A = np.sin(theta) / theta
        B = (1.0 - np.cos(theta)) / theta_sq
    return np.eye(3) + A * Phi + B * (Phi @ Phi)


def _log_so3(R: np.ndarray) -> np.ndarray:
    """Logarithmic map on SO(3): rotation matrix -> rotation vector.

    Parameters
    ----------
    R : (3, 3) rotation matrix

    Returns
    -------
    (3,) rotation vector ``theta * n`` with ``theta`` in [0, pi].

    Three regimes are handled separately:
      * theta ~ 0   : first-order formula ``vee((R - R^T) / 2)``;
      * theta ~ pi  : ``R - R^T`` vanishes, so the axis is read from the
                      symmetric part of ``R`` instead;
      * otherwise   : the standard ``theta / (2 sin theta) * (R - R^T)``.
    """
    cos_theta = np.clip(0.5 * (np.trace(R) - 1.0), -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-10:
        return _vee(0.5 * (R - R.T))
    if abs(theta - np.pi) < 1e-6:
        # Symmetric part + I equals (1 + cos)I + (1 - cos) n n^T; using it
        # removes the sin(theta)*hat(n) contamination present in R + I.
        M = 0.5 * (R + R.T) + np.eye(3)
        norms = np.linalg.norm(M, axis=0)
        k = np.argmax(norms)
        v = M[:, k] / norms[k]
        # A column of n n^T only determines the axis up to sign.  Just below
        # pi the antisymmetric part still equals sin(theta)*hat(n) with
        # sin(theta) >= 0, so use it to orient the axis.  At exactly pi both
        # signs describe the same rotation and v is left as is.
        if np.dot(v, _vee(0.5 * (R - R.T))) < 0.0:
            v = -v
        return v * theta
    return _vee(theta / (2.0 * np.sin(theta)) * (R - R.T))


def _project_to_so3(R: np.ndarray) -> np.ndarray:
    """Project to nearest rotation matrix via SVD (safety net).

    If ``U @ Vt`` has negative determinant, the last column of ``U`` is
    flipped so the returned matrix has determinant +1.
    """
    U, _, Vt = np.linalg.svd(R)
    Rp = U @ Vt
    if np.linalg.det(Rp) < 0:
        U[:, -1] *= -1.0
        Rp = U @ Vt
    return Rp


def _random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Uniform random rotation via Shoemake's method.

    Draws three uniforms from ``rng``, forms a unit quaternion
    ``(x, y, z, w)`` and converts it to a (3, 3) rotation matrix.
    """
    u1, u2, u3 = rng.random(3)
    q1 = np.sqrt(1 - u1) * np.sin(2 * np.pi * u2)
    q2 = np.sqrt(1 - u1) * np.cos(2 * np.pi * u2)
    q3 = np.sqrt(u1) * np.sin(2 * np.pi * u3)
    q4 = np.sqrt(u1) * np.cos(2 * np.pi * u3)
    x, y, z, w = q1, q2, q3, q4
    return np.array([
        [1 - 2*(y*y + z*z),     2*(x*y - z*w),     2*(x*z + y*w)],
        [    2*(x*y + z*w), 1 - 2*(x*x + z*z),     2*(y*z - x*w)],
        [    2*(x*z - y*w),     2*(y*z + x*w), 1 - 2*(x*x + y*y)],
    ], dtype=np.float64)


# ──────────────────────────────────────────────────────────────────────────────
# Inertia computation
# ──────────────────────────────────────────────────────────────────────────────

def compute_racket_inertia(
    head_a: float,
    head_b: float,
    handle_length: float,
    handle_radius: float,
    total_mass: float,
    head_mass_ratio: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute the principal inertia tensor for the tennis-racket model.

    The racket is modelled as an elliptical hoop (head) with semi-axes
    ``head_a`` x ``head_b`` plus a thin cylindrical rod (handle) whose centre
    is offset ``d = head_a + handle_length / 2`` from the hoop centre.

    NOTE(review): the rod is described as lying along -e₁, yet the code adds
    the rod's small axial moment ``m_r r² / 2`` to the e₃ entry and the
    parallel-axis term to the e₁ and e₂ entries.  The entries are sorted
    afterwards, so the labels are reassigned anyway — confirm the intended
    handle direction before relying on the geometric interpretation.

    Parameters
    ----------
    head_a          : semi-axis along e₁ (long direction of head), metres
    head_b          : semi-axis along e₂ (short direction of head), metres
    handle_length   : length of handle rod, metres
    handle_radius   : radius of handle rod (only enters the third rod term), metres
    total_mass      : total racket mass, kg
    head_mass_ratio : fraction of total mass in the head hoop, (0, 1)

    Returns
    -------
    I     : (3,3) diagonal inertia tensor, diagonal sorted ascending [I₁, I₂, I₃]
    I_inv : inverse of I
    P     : (3,3) permutation matrix ``np.eye(3)[:, perm]``; column j is the
            unit vector of the geometric axis that became sorted body axis j
    """
    m_h = total_mass * head_mass_ratio           # head mass
    m_r = total_mass * (1.0 - head_mass_ratio)   # handle mass

    # ── Head (elliptical hoop) ──
    I_h1 = 0.5 * m_h * head_b**2
    I_h2 = 0.5 * m_h * head_a**2
    I_h3 = 0.5 * m_h * (head_a**2 + head_b**2)

    # ── Handle (thin rod) ──
    # Parallel-axis theorem shifts the rod's own mid-point moment by m_r d².
    d = head_a + 0.5 * handle_length
    I_rod_own = (1.0 / 12.0) * m_r * handle_length**2

    I_r1 = I_rod_own + m_r * d**2
    I_r2 = I_rod_own + m_r * d**2
    I_r3 = 0.5 * m_r * handle_radius**2

    # ── Total in original geometric axes [e1_geom, e2_geom, e3_geom] ──
    I_geom = np.array([
        I_h1 + I_r1,
        I_h2 + I_r2,
        I_h3 + I_r3,
    ], dtype=np.float64)

    # Sort so body-frame axes are always labelled:
    #   axis 0 = smallest I, axis 1 = intermediate I, axis 2 = largest I
    perm = np.argsort(I_geom)
    I_sorted = I_geom[perm]

    # Permutation matrix: columns are old geometric axes selected for new body axes
    P = np.eye(3)[:, perm]

    I_diag = np.diag(I_sorted)
    return I_diag, np.linalg.inv(I_diag), P
+ + State obs (12,): [R.flatten() (9), ω (3)] + Action (3,): body-frame torque τ + + Principal axes convention (body frame): + axis 0 (e₁): long head axis — smallest I — stable rotation + axis 1 (e₂): short head axis — middle I — UNSTABLE rotation + axis 2 (e₃): handle axis — largest I — stable rotation + """ + + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} + + def __init__( + self, + # ── Geometry & mass ────────────────────────────────────────────── + head_a: float = 0.195, # head semi-axis (long), m + head_b: float = 0.135, # head semi-axis (short), m + handle_length: float = 0.25, # handle length, m + handle_radius: float = 0.015, # handle radius, m + total_mass: float = 0.290, # total mass, kg + head_mass_ratio: float = 0.70, # fraction of mass in head + # ── Integration ────────────────────────────────────────────────── + dt: float = 0.05, + max_torque: float = 2.0, # action space bound (API only) + # ── Initial condition ───────────────────────────────────────────── + omega_0_scale: float = 2 * np.pi, # base spin rate, rad/s (~1 rev/s) + perturb_std: float = 0.05, # perturbation std, rad/s + axis_weights: Tuple[float, float, float] = (0.15, 0.70, 0.15), + # ── Disturbance torque (constant per trajectory) ────────────────── + disturbance_torque_std: float = 0.0, # 0 = off; e.g. 
0.05 N·m + # ── Wind / friction (disabled by default) ───────────────────────── + # [WIND] wind_force_std: float = 0.0, + # [FRICTION] friction_coeff: Union[float, Tuple] = 0.0, + # [FRICTION] varying_friction: bool = False, + # ── Misc ────────────────────────────────────────────────────────── + ori_rep: str = "rotmat", + render_mode: Optional[str] = None, + seed: Optional[int] = None, + ): + super().__init__() + + if ori_rep != "rotmat": + raise ValueError("Only ori_rep='rotmat' is supported.") + + # ── Geometry ── + self.head_a = float(head_a) + self.head_b = float(head_b) + self.handle_length = float(handle_length) + self.handle_radius = float(handle_radius) + self.total_mass = float(total_mass) + self.head_mass_ratio = float(head_mass_ratio) + + # ── Inertia ── + self.I, self.I_inv, self.P_body_from_geom = compute_racket_inertia( + head_a, head_b, handle_length, handle_radius, + total_mass, head_mass_ratio + ) + # Convenience: diagonal values in order [I1, I2, I3] + self.I_diag = np.diag(self.I) + + # ── Integration ── + self.dt = float(dt) + self.max_torque = float(max_torque) + self.ori_rep = ori_rep + + # ── Initial condition ── + self.omega_0_scale = float(omega_0_scale) + self.perturb_std = float(perturb_std) + axis_w = np.array(axis_weights, dtype=np.float64) + self.axis_weights = axis_w / axis_w.sum() # normalise to sum=1 + + # ── Disturbance torque ── + self.disturbance_torque_std = float(disturbance_torque_std) + self._disturbance_torque = np.zeros(3, dtype=np.float64) + + # ── Wind / friction (disabled) ── + # [WIND] self.wind_force_std = float(wind_force_std) + # [FRICTION] self.friction_coeff = np.broadcast_to( + # [FRICTION] np.asarray(friction_coeff, dtype=np.float64), (3,)).copy() + # [FRICTION] self.varying_friction = bool(varying_friction) + + # ── Spaces ── + self.action_space = spaces.Box( + low=-self.max_torque, high=self.max_torque, + shape=(3,), dtype=np.float32 + ) + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, 
shape=(12,), dtype=np.float32 + ) + + # ── State ── + self._np_rng = np.random.default_rng(seed) + self.t = 0.0 + self.last_u = np.zeros(3, dtype=np.float64) + self.R = np.eye(3, dtype=np.float64) + self.omega = np.zeros(3, dtype=np.float64) + + self._fig = None + self._ax = None + self.render_mode = render_mode + + # ── Properties ──────────────────────────────────────────────────────── + + @property + def intermediate_axis(self) -> int: + """Index of the intermediate (unstable) principal axis (always 1).""" + return 1 + + def get_state(self): + return self.R.copy(), self.omega.copy() + + def get_inertia_info(self) -> dict: + """Return a dict summarising the inertia tensor for logging.""" + return { + "I1": self.I_diag[0], + "I2": self.I_diag[1], + "I3": self.I_diag[2], + "head_a": self.head_a, + "head_b": self.head_b, + "handle_length": self.handle_length, + "total_mass": self.total_mass, + "head_mass_ratio": self.head_mass_ratio, + } + + # ── Seeding ─────────────────────────────────────────────────────────── + + def seed(self, seed: Optional[int] = None): + self._np_rng = np.random.default_rng(seed) + + # ── Wind model (disabled — uncomment to enable) ─────────────────────── + + # [WIND] def update_wind(self, t: float) -> np.ndarray: + # [WIND] """Return a 3-vector stochastic wind torque increment.""" + # [WIND] if self.wind_force_std > 0.0: + # [WIND] return self._np_rng.normal(0.0, self.wind_force_std, size=3) + # [WIND] return np.zeros(3) + + # ── Friction (disabled — uncomment to enable) ───────────────────────── + + # [FRICTION] def _variable_friction(self, omega: np.ndarray) -> np.ndarray: + # [FRICTION] if not self.varying_friction: + # [FRICTION] return self.friction_coeff + # [FRICTION] speed_term = np.tanh(np.linalg.norm(omega)) + # [FRICTION] return self.friction_coeff * (1.0 + 0.5 * speed_term) + + # ── Observation ─────────────────────────────────────────────────────── + + def _get_obs(self) -> np.ndarray: + return np.hstack((self.R.reshape(-1), 
def _lie_heun_step(
    self,
    R: np.ndarray,
    omega: np.ndarray,
    u: np.ndarray,
    h: float,
    # [WIND] sigma: float, dW: np.ndarray
):
    """Advance (R, ω) by one Heun (predictor-corrector) substep of size ``h``.

    ω is integrated with the explicit trapezoidal rule using
    ``self._compute_omega_rates``; R is advanced on SO(3) by right-multiplying
    with the exponential of the averaged rotation increment, so the update is
    a product of rotations rather than an additive step.

    Args:
        R: current rotation matrix.
        omega: current body-frame angular velocity (3-vector).
        u: control torque, held constant over the substep.
        h: substep length.

    Returns:
        ``(R_new, omega_new)``.
    """
    # Stage 1: slope and rotation increment at the current state.
    omega_dot_1 = self._compute_omega_rates(omega, u)
    phi_1 = omega * h

    # Predictor (explicit Euler) for ω only.  _compute_omega_rates takes just
    # (omega, u), so a predicted rotation is never needed; the former unused
    # `R_pred = R @ _exp_so3(phi_1)` cost one matrix exponential per substep.
    omega_pred = omega + omega_dot_1 * h

    # Stage 2: slope and rotation increment at the predicted state.
    omega_dot_2 = self._compute_omega_rates(omega_pred, u)
    phi_2 = omega_pred * h

    # Corrector: average both stages.
    phi_avg = 0.5 * (phi_1 + phi_2)
    R_new = R @ _exp_so3(phi_avg)
    omega_new = omega + 0.5 * (omega_dot_1 + omega_dot_2) * h

    return R_new, omega_new
+ + options keys (all optional) + --------------------------- + "R_init" : (3,3) initial rotation matrix (projected to SO(3)) + "omega_init" : (3,) initial angular velocity + "axis" : 0, 1, or 2 — which principal axis to spin about; + if absent, sampled according to self.axis_weights + """ + super().reset(seed=seed) + if seed is not None: + self.seed(seed) + + self.t = 0.0 + self.last_u = np.zeros(3, dtype=np.float64) + + if options is None: + options = {} + + # ── Orientation ── + if "R_init" in options: + R0 = _project_to_so3( + np.asarray(options["R_init"], dtype=np.float64).reshape(3, 3) + ) + else: + R0 = _random_rotation(self._np_rng) + + # ── Angular velocity ── + if "omega_init" in options: + w0 = np.asarray(options["omega_init"], dtype=np.float64).reshape(3) + else: + # Choose rotation axis + if "axis" in options: + axis = int(options["axis"]) + else: + axis = int(self._np_rng.choice(3, p=self.axis_weights)) + + # Spin magnitude (positive or negative with equal probability) + sign = self._np_rng.choice([-1.0, 1.0]) + speed = sign * self.omega_0_scale + + # Base angular velocity along chosen axis + small perturbation + w0 = np.zeros(3, dtype=np.float64) + w0[axis] = speed + w0 += self._np_rng.normal(0.0, self.perturb_std, size=3) + + # ── Per-trajectory disturbance torque ── + # Sampled once per episode; held constant throughout. + # Represents slow aerodynamic asymmetry. 
def render(self):
    """Draw the current state; in ``"rgb_array"`` mode also return the frame.

    Everything is drawn in world coordinates: the three body axes
    (columns of ``R``), the head ellipse and handle rod (mapped from the
    geometric frame through ``P_body_from_geom.T`` and then ``R``), and the
    angular-velocity arrow.

    Returns:
        ``None`` when ``render_mode`` is ``None`` or ``"human"``; otherwise a
        copy of the first three channels of the canvas RGBA buffer.
    """
    if self.render_mode is None:
        return

    import matplotlib.pyplot as plt

    # (Re)create the figure if it was never made or the window was closed.
    if self._fig is None or not plt.fignum_exists(self._fig.number):
        self._init_render()

    ax = self._ax
    ax.cla()

    # Draw principal axes of the body frame
    axis_len = 0.25
    axis_colors = ["#e74c3c", "#2ecc71", "#3498db"]
    axis_labels = ["e₁ (stable, long)", "e₂ (unstable)", "e₃ (stable, handle)"]
    origin = np.zeros(3)
    for i in range(3):
        # R @ e_i is column i of R: body axis i expressed in the world frame.
        tip = axis_len * self.R[:, i]
        ax.quiver(
            *origin, *tip,
            color=axis_colors[i], linewidth=2.5,
            arrow_length_ratio=0.18,
        )
        ax.text(*(tip * 1.1), axis_labels[i],
                color=axis_colors[i], fontsize=7, fontweight="bold")

    # Draw a rough ellipse for the racket head
    theta = np.linspace(0, 2 * np.pi, 80)
    head_pts_geom = np.column_stack([
        self.head_a * np.cos(theta),
        self.head_b * np.sin(theta),
        np.zeros_like(theta),
    ])  # (80, 3) in geometric frame

    head_pts_body = (self.P_body_from_geom.T @ head_pts_geom.T).T
    head_pts_world = (self.R @ head_pts_body.T).T
    ax.plot(head_pts_world[:, 0], head_pts_world[:, 1], head_pts_world[:, 2],
            color="#888888", linewidth=1.5, alpha=0.8)

    # Draw handle rod
    handle_start_geom = np.array([-self.head_a, 0.0, 0.0])
    handle_end_geom = np.array([-self.head_a - self.handle_length, 0.0, 0.0])

    handle_start_body = self.P_body_from_geom.T @ handle_start_geom
    handle_end_body = self.P_body_from_geom.T @ handle_end_geom
    hs = self.R @ handle_start_body
    he = self.R @ handle_end_body
    ax.plot([hs[0], he[0]], [hs[1], he[1]], [hs[2], he[2]],
            color="#555555", linewidth=4)

    # Draw ω vector.  self.omega is expressed in the body frame (the dynamics
    # are Euler's equations in the body frame), while every other artist in
    # this plot is in the world frame, so rotate it by R before drawing.
    # Previously the raw body-frame components were plotted, which made the
    # arrow inconsistent with the drawn body axes.
    omega_scale = 0.08
    ov = (self.R @ self.omega) * omega_scale
    ax.quiver(*origin, *ov, color="#ff9800", linewidth=2,
              arrow_length_ratio=0.2, label=f"|ω|={np.linalg.norm(self.omega):.2f}")

    lim = self.head_a * 1.8
    ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim); ax.set_zlim(-lim, lim)
    ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")
    ax.set_title("Tennis Racket — Free Rigid Body on SO(3)", fontweight="bold")
    ax.set_box_aspect([1, 1, 1])

    # HUD: squared cosine between ω and body axis 1, evaluated in the body
    # frame (so the un-rotated self.omega is the right vector here).
    omega_norm = np.linalg.norm(self.omega)
    e2 = np.array([0.0, 1.0, 0.0])
    align = float(np.dot(self.omega / (omega_norm + 1e-12), e2)) ** 2
    hud = (
        f"t={self.t:.2f}s  "
        f"|ω|={omega_norm:.2f} rad/s  "
        f"e₂-align={align:.3f}"
    )
    ax.text2D(0.02, 0.96, hud, transform=ax.transAxes,
              fontsize=9, fontfamily="monospace",
              verticalalignment="top",
              bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    ax.legend(loc="upper right", fontsize=8, framealpha=0.7)

    if self.render_mode == "human":
        plt.draw(); plt.pause(0.001)
    else:
        self._fig.canvas.draw()
        buf = self._fig.canvas.buffer_rgba()
        return np.asarray(buf)[:, :, :3].copy()
render_mode="rgb_array", + seed=42, + ) + # Start spinning about the intermediate axis with a small perturbation + obs, _ = env.reset(seed=42, options={"axis": 1}) + print("Inertia:", env.get_inertia_info()) + print(f"I1={env.I_diag[0]:.5f} I2={env.I_diag[1]:.5f} I3={env.I_diag[2]:.5f} kg·m²") + + frames = [] + frame = env.render() + if frame is not None: + frames.append(frame) + + for step in range(N_STEPS): + obs, reward, _, _, info = env.step([0.0, 0.0, 0.0]) + frame = env.render() + if frame is not None: + frames.append(frame) + if step % 40 == 0: + R, omega = env.get_state() + print( + f"Step {step:4d} | reward={reward:.4f} | " + f"|ω|={np.linalg.norm(omega):.3f} | " + f"det(R)={np.linalg.det(R):.8f}" + ) + + env.close() + + os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True) + fig_vid, ax_vid = plt.subplots(figsize=(7, 7)) + ax_vid.axis("off") + im = ax_vid.imshow(frames[0]) + + def _update(i): + im.set_data(frames[i]) + return [im] + + anim = FuncAnimation(fig_vid, _update, frames=len(frames), + interval=1000 / 30, blit=True) + anim.save(SAVE_PATH, writer="ffmpeg", fps=30, dpi=100) + plt.close(fig_vid) + print(f"Saved to {SAVE_PATH}") + diff --git a/envs/tennis_racket_3d_friction.py b/envs/tennis_racket_3d_friction.py new file mode 100644 index 0000000..adebaa5 --- /dev/null +++ b/envs/tennis_racket_3d_friction.py @@ -0,0 +1,691 @@ +from __future__ import annotations + +"""Free rigid body (tennis racket) environment on SO(3). + +Models a tennis racket as an elliptical hoop (head) + cylindrical rod (handle). +The principal moments of inertia are computed analytically from geometry. +The equations of motion are Euler's equations for a torque-free (or torqued) +rigid body, integrated with the same Lie-group Heun scheme as windy_pendulum_3d. 
+ +State : (R ∈ SO(3), ω ∈ ℝ³) → obs vector length 12 (same as pendulum) +Action : τ ∈ ℝ³ (body-frame torque) length 3 + +Body-frame principal axes convention +------------------------------------- + e₁ — long axis of head (largest semi-axis a, SMALLEST moment I₁) + e₂ — short axis of head (semi-axis b, INTERMEDIATE I₂) + e₃ — handle axis (out-of-plane, LARGEST moment I₃) + +compute_racket_inertia sorts the three computed moments, so the body frame +always satisfies I₁ ≤ I₂ ≤ I₃ (the permutation from geometric to body axes is +returned alongside the tensor) and rotation about e₂ is the unstable +intermediate axis (Dzhanibekov effect). + +Inertia formulae (thin-shell / thin-rod approximations) +--------------------------------------------------------- +Elliptical hoop (mass m_h, semi-axes a, b): + I_hoop_1 = (1/2) m_h b² (about e₁) + I_hoop_2 = (1/2) m_h a² (about e₂) + I_hoop_3 = (1/2) m_h (a² + b²) (about e₃) + +Thin rod (mass m_r, length L, radius r_r ≪ L): + COM of rod is offset d = a + L/2 from racket COM (along -e₁ / handle dir) + I_rod_1 = (1/12) m_r L² + m_r d² (parallel axis, about e₁) + I_rod_2 = (1/12) m_r L² + m_r d² (same by symmetry) + I_rod_3 = (1/2) m_r r_r² (about handle axis ≈ 0 for thin rod) + +Total: I_i = I_hoop_i + I_rod_i for i = 1, 2, 3. + +Wind / friction (wind commented out; friction enabled in this module) +--------------------------------------------------- +The wind block mirrors windy_pendulum_3d exactly. To enable wind, uncomment the +sections marked # [WIND] and pass the relevant kwargs. Friction is already +active here: it is controlled by the friction_coeff and varying_friction +kwargs, and the default friction_coeff = 0.0 applies no friction torque.
+""" + +from typing import Optional, Tuple, Union + +import numpy as np +import gymnasium as gym +from gymnasium import spaces + + +# ────────────────────────────────────────────────────────────────────────────── +# SO(3) Lie group utilities (identical to windy_pendulum_3d) +# ────────────────────────────────────────────────────────────────────────────── + +def _hat(w: np.ndarray) -> np.ndarray: + """Skew-symmetric matrix (hat map) for w in R^3.""" + wx, wy, wz = w + return np.array([[0.0, -wz, wy], + [wz, 0.0, -wx], + [-wy, wx, 0.0]], dtype=np.float64) + + +def _vee(W: np.ndarray) -> np.ndarray: + """Inverse hat map (vee).""" + return np.array([W[2, 1], W[0, 2], W[1, 0]], dtype=np.float64) + + +def _exp_so3(phi: np.ndarray) -> np.ndarray: + """Matrix exponential on so(3) via Rodrigues' formula.""" + theta_sq = np.dot(phi, phi) + theta = np.sqrt(theta_sq) + Phi = _hat(phi) + if theta < 1e-10: + A = 1.0 - theta_sq / 6.0 + B = 0.5 - theta_sq / 24.0 + else: + A = np.sin(theta) / theta + B = (1.0 - np.cos(theta)) / theta_sq + return np.eye(3) + A * Phi + B * (Phi @ Phi) + + +def _log_so3(R: np.ndarray) -> np.ndarray: + """Logarithmic map on SO(3).""" + cos_theta = np.clip(0.5 * (np.trace(R) - 1.0), -1.0, 1.0) + theta = np.arccos(cos_theta) + if theta < 1e-10: + return _vee(0.5 * (R - R.T)) + elif abs(theta - np.pi) < 1e-6: + M = R + np.eye(3) + norms = np.linalg.norm(M, axis=0) + k = np.argmax(norms) + v = M[:, k] / norms[k] + return v * theta + else: + return _vee(theta / (2.0 * np.sin(theta)) * (R - R.T)) + + +def _project_to_so3(R: np.ndarray) -> np.ndarray: + """Project to nearest rotation matrix via SVD (safety net).""" + U, _, Vt = np.linalg.svd(R) + Rp = U @ Vt + if np.linalg.det(Rp) < 0: + U[:, -1] *= -1.0 + Rp = U @ Vt + return Rp + + +def _random_rotation(rng: np.random.Generator) -> np.ndarray: + """Uniform random rotation via Shoemake's method.""" + u1, u2, u3 = rng.random(3) + q1 = np.sqrt(1 - u1) * np.sin(2 * np.pi * u2) + q2 = np.sqrt(1 - u1) * np.cos(2 
def compute_racket_inertia(
    head_a: float,
    head_b: float,
    handle_length: float,
    handle_radius: float,
    total_mass: float,
    head_mass_ratio: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute the sorted principal inertia tensor for the tennis-racket model.

    The racket is modelled as an elliptical hoop (head, semi-axes a × b) plus
    a rod (handle).  The three moments are first assembled in the geometric
    axis order and then sorted ascending, so the returned body frame always
    has axis 0 = smallest, axis 1 = intermediate, axis 2 = largest moment.

    Parameters
    ----------
    head_a          : head semi-axis along geometric e₁, metres (> 0)
    head_b          : head semi-axis along geometric e₂, metres (> 0)
    handle_length   : length of handle rod, metres (>= 0)
    handle_radius   : radius of handle rod, metres (>= 0)
    total_mass      : total racket mass, kg (> 0)
    head_mass_ratio : fraction of total mass in the head hoop, in [0, 1]

    Returns
    -------
    I     : (3,3) diagonal inertia tensor diag(I₁, I₂, I₃), ascending
    I_inv : inverse of I
    P     : (3,3) permutation matrix ``np.eye(3)[:, perm]``; column i is the
            geometric axis that became body axis i, so a geometric-frame
            vector v maps to the body frame as ``P.T @ v``

    Raises
    ------
    ValueError
        If a parameter is outside the ranges above, or the resulting moments
        are not all finite and strictly positive (the tensor would not be
        invertible / physically meaningful).
    """
    # Reject non-physical inputs up front: previously e.g. head_mass_ratio > 1
    # produced a negative handle mass and silently returned a negative inertia.
    if not total_mass > 0.0:
        raise ValueError(f"total_mass must be > 0, got {total_mass!r}")
    if not 0.0 <= head_mass_ratio <= 1.0:
        raise ValueError(
            f"head_mass_ratio must lie in [0, 1], got {head_mass_ratio!r}"
        )
    if head_a <= 0.0 or head_b <= 0.0:
        raise ValueError(
            f"head_a and head_b must be > 0, got {head_a!r}, {head_b!r}"
        )
    if handle_length < 0.0 or handle_radius < 0.0:
        raise ValueError(
            "handle_length and handle_radius must be >= 0, got "
            f"{handle_length!r}, {handle_radius!r}"
        )

    m_h = total_mass * head_mass_ratio          # head mass
    m_r = total_mass * (1.0 - head_mass_ratio)  # handle mass

    # ── Head (elliptical hoop) ──
    # Hoop moments used by this model:
    #   about e₁ : (1/2) m b²,  about e₂ : (1/2) m a²,  about e₃ : (1/2) m (a²+b²)
    I_h1 = 0.5 * m_h * head_b**2
    I_h2 = 0.5 * m_h * head_a**2
    I_h3 = 0.5 * m_h * (head_a**2 + head_b**2)

    # ── Handle (rod) ──
    # Rod midpoint sits at distance d from the origin; the parallel-axis term
    # m_r d² is added to the rod's own transverse moment (1/12) m_r L².
    d = head_a + 0.5 * handle_length
    I_rod_own = (1.0 / 12.0) * m_r * handle_length**2

    # NOTE(review): the transverse (parallel-axis) moment is assigned to
    # geometric axes 1 and 2 and the axial moment (1/2) m_r r² to axis 3,
    # whereas render() draws the handle along geometric -e₁, for which the
    # axial moment would belong to axis 1.  Left unchanged because the
    # resulting numbers are presumably relied on by existing data — confirm
    # the intended handle direction before altering.
    I_r1 = I_rod_own + m_r * d**2
    I_r2 = I_rod_own + m_r * d**2
    I_r3 = 0.5 * m_r * handle_radius**2

    # ── Total in original geometric axes [e1_geom, e2_geom, e3_geom] ──
    I_geom = np.array([
        I_h1 + I_r1,
        I_h2 + I_r2,
        I_h3 + I_r3,
    ], dtype=np.float64)

    # Catch NaN/inf inputs and degenerate cases (e.g. a zero moment) that
    # would otherwise surface as a singular-matrix error or garbage dynamics.
    if not np.all(np.isfinite(I_geom)) or np.any(I_geom <= 0.0):
        raise ValueError(
            f"principal moments must be finite and > 0, got {I_geom.tolist()}"
        )

    # Sort so body-frame axes are always labeled:
    #   axis 0 = smallest I, axis 1 = intermediate I, axis 2 = largest I
    perm = np.argsort(I_geom)
    I_sorted = I_geom[perm]

    # Permutation matrix: columns are old geometric axes selected for new body axes
    P = np.eye(3)[:, perm]

    I_diag = np.diag(I_sorted)
    return I_diag, np.linalg.inv(I_diag), P
+ + State obs (12,): [R.flatten() (9), ω (3)] + Action (3,): body-frame torque τ + + Principal axes convention (body frame): + axis 0 (e₁): long head axis — smallest I — stable rotation + axis 1 (e₂): short head axis — middle I — UNSTABLE rotation + axis 2 (e₃): handle axis — largest I — stable rotation + """ + + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} + + def __init__( + self, + # ── Geometry & mass ────────────────────────────────────────────── + head_a: float = 0.195, # head semi-axis (long), m + head_b: float = 0.135, # head semi-axis (short), m + handle_length: float = 0.25, # handle length, m + handle_radius: float = 0.015, # handle radius, m + total_mass: float = 0.290, # total mass, kg + head_mass_ratio: float = 0.70, # fraction of mass in head + # ── Integration ────────────────────────────────────────────────── + dt: float = 0.05, + max_torque: float = 2.0, # action space bound (API only) + # ── Initial condition ───────────────────────────────────────────── + omega_0_scale: float = 2 * np.pi, # base spin rate, rad/s (~1 rev/s) + perturb_std: float = 0.05, # perturbation std, rad/s + axis_weights: Tuple[float, float, float] = (0.15, 0.70, 0.15), + # ── Disturbance torque (constant per trajectory) ────────────────── + disturbance_torque_std: float = 0.0, # 0 = off; e.g. 
0.05 N·m + # ── Wind / friction (disabled by default) ───────────────────────── + # [WIND] wind_force_std: float = 0.0, + friction_coeff: Union[float, Tuple] = 0.0, + varying_friction: bool = False, + # ── Misc ────────────────────────────────────────────────────────── + ori_rep: str = "rotmat", + render_mode: Optional[str] = None, + seed: Optional[int] = None, + ): + super().__init__() + + if ori_rep != "rotmat": + raise ValueError("Only ori_rep='rotmat' is supported.") + + # ── Geometry ── + self.head_a = float(head_a) + self.head_b = float(head_b) + self.handle_length = float(handle_length) + self.handle_radius = float(handle_radius) + self.total_mass = float(total_mass) + self.head_mass_ratio = float(head_mass_ratio) + + # ── Inertia ── + self.I, self.I_inv, self.P_body_from_geom = compute_racket_inertia( + head_a, head_b, handle_length, handle_radius, + total_mass, head_mass_ratio + ) + # Convenience: diagonal values in order [I1, I2, I3] + self.I_diag = np.diag(self.I) + + # ── Integration ── + self.dt = float(dt) + self.max_torque = float(max_torque) + self.ori_rep = ori_rep + + # ── Initial condition ── + self.omega_0_scale = float(omega_0_scale) + self.perturb_std = float(perturb_std) + axis_w = np.array(axis_weights, dtype=np.float64) + self.axis_weights = axis_w / axis_w.sum() # normalise to sum=1 + + # ── Disturbance torque ── + self.disturbance_torque_std = float(disturbance_torque_std) + self._disturbance_torque = np.zeros(3, dtype=np.float64) + + # ── Wind / friction (disabled) ── + # [WIND] self.wind_force_std = float(wind_force_std) + self.friction_coeff = np.broadcast_to( + np.asarray(friction_coeff, dtype=np.float64), (3,)).copy() + self.varying_friction = bool(varying_friction) + + # ── Spaces ── + self.action_space = spaces.Box( + low=-self.max_torque, high=self.max_torque, + shape=(3,), dtype=np.float32 + ) + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32 + ) + + # ── State ── + self._np_rng = 
np.random.default_rng(seed) + self.t = 0.0 + self.last_u = np.zeros(3, dtype=np.float64) + self.R = np.eye(3, dtype=np.float64) + self.omega = np.zeros(3, dtype=np.float64) + + self._fig = None + self._ax = None + self.render_mode = render_mode + + # ── Properties ──────────────────────────────────────────────────────── + + @property + def intermediate_axis(self) -> int: + """Index of the intermediate (unstable) principal axis (always 1).""" + return 1 + + def get_state(self): + return self.R.copy(), self.omega.copy() + + def get_inertia_info(self) -> dict: + """Return a dict summarising the inertia tensor for logging.""" + return { + "I1": self.I_diag[0], + "I2": self.I_diag[1], + "I3": self.I_diag[2], + "head_a": self.head_a, + "head_b": self.head_b, + "handle_length": self.handle_length, + "total_mass": self.total_mass, + "head_mass_ratio": self.head_mass_ratio, + } + + # ── Seeding ─────────────────────────────────────────────────────────── + + def seed(self, seed: Optional[int] = None): + self._np_rng = np.random.default_rng(seed) + + # ── Wind model (disabled — uncomment to enable) ─────────────────────── + + # [WIND] def update_wind(self, t: float) -> np.ndarray: + # [WIND] """Return a 3-vector stochastic wind torque increment.""" + # [WIND] if self.wind_force_std > 0.0: + # [WIND] return self._np_rng.normal(0.0, self.wind_force_std, size=3) + # [WIND] return np.zeros(3) + + # ── Friction (disabled — uncomment to enable) ───────────────────────── + + def _variable_friction(self, omega: np.ndarray) -> np.ndarray: + if not self.varying_friction: + return self.friction_coeff + speed_term = np.tanh(np.linalg.norm(omega)) + return self.friction_coeff * (1.0 + 0.5 * speed_term) + + # ── Observation ─────────────────────────────────────────────────────── + + def _get_obs(self) -> np.ndarray: + return np.hstack((self.R.reshape(-1), self.omega)).astype(np.float32) + + # ── Dynamics ────────────────────────────────────────────────────────── + + def 
_compute_omega_rates( + self, + omega: np.ndarray, + u: np.ndarray, + # [WIND] dW: np.ndarray, sigma: float + ): + """Compute deterministic ω̇ for given angular velocity and torque. + + Euler's equations (body frame): + I ω̇ = τ_total - ω × (I ω) + + where τ_total includes the control torque u, the per-trajectory + constant disturbance, and (when enabled) friction and wind. + """ + tau = u + self._disturbance_torque + + fric = self._variable_friction(omega) + tau = tau - fric * omega + + # [WIND] if sigma > 0.0: + # [WIND] tau = tau + sigma * dW # additive stochastic torque + + Iw = self.I @ omega + omega_dot = self.I_inv @ (tau - np.cross(omega, Iw)) + return omega_dot + + # ── Lie-group Heun integrator ───────────────────────────────────────── + + def _lie_heun_step( + self, + R: np.ndarray, + omega: np.ndarray, + u: np.ndarray, + h: float, + # [WIND] sigma: float, dW: np.ndarray + ): + """One Stratonovich-Heun substep on SO(3) × ℝ³. + + Identical structure to windy_pendulum_3d._lie_heun_step, but + the dynamics are Euler's equations for a free rigid body (no + gravity, no pivot constraint). 
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
    """Reset the environment and return ``(observation, info)``.

    options keys (all optional)
    ---------------------------
    "R_init"     : (3,3) initial rotation matrix (projected to SO(3))
    "omega_init" : (3,) initial angular velocity
    "axis"       : 0, 1, or 2 — which principal axis to spin about;
                   if absent, sampled according to self.axis_weights

    When ``seed`` is given, the environment's own generator is re-created
    from it, so the sampled orientation, spin and disturbance repeat.

    Raises
    ------
    ValueError
        If ``options["axis"]`` is not 0, 1 or 2.  (Previously ``-1`` silently
        selected axis 2 through negative indexing and ``3`` surfaced as a
        bare ``IndexError``.)
    """
    super().reset(seed=seed)
    if seed is not None:
        self.seed(seed)

    self.t = 0.0
    self.last_u = np.zeros(3, dtype=np.float64)

    if options is None:
        options = {}

    # NOTE: the order of the random draws below (orientation → axis → sign →
    # perturbation → disturbance) determines what a given seed produces; keep
    # it unchanged.

    # ── Orientation ──
    if "R_init" in options:
        R0 = _project_to_so3(
            np.asarray(options["R_init"], dtype=np.float64).reshape(3, 3)
        )
    else:
        R0 = _random_rotation(self._np_rng)

    # ── Angular velocity ──
    if "omega_init" in options:
        w0 = np.asarray(options["omega_init"], dtype=np.float64).reshape(3)
    else:
        # Choose rotation axis
        if "axis" in options:
            axis = int(options["axis"])
            if axis not in (0, 1, 2):
                raise ValueError(
                    f"options['axis'] must be 0, 1 or 2, got {options['axis']!r}"
                )
        else:
            axis = int(self._np_rng.choice(3, p=self.axis_weights))

        # Spin magnitude (positive or negative with equal probability)
        sign = self._np_rng.choice([-1.0, 1.0])
        speed = sign * self.omega_0_scale

        # Base angular velocity along chosen axis + small perturbation
        w0 = np.zeros(3, dtype=np.float64)
        w0[axis] = speed
        w0 += self._np_rng.normal(0.0, self.perturb_std, size=3)

    # ── Per-trajectory disturbance torque ──
    # Sampled once per episode; held constant throughout.
    if self.disturbance_torque_std > 0.0:
        self._disturbance_torque = self._np_rng.normal(
            0.0, self.disturbance_torque_std, size=3
        )
    else:
        self._disturbance_torque = np.zeros(3, dtype=np.float64)

    self.R = R0
    self.omega = w0

    return self._get_obs(), {}
def close(self):
    """Close the render figure, if one exists, and forget its handles.

    Does nothing when no figure has been created; otherwise the figure is
    closed through pyplot and both ``_fig`` and ``_ax`` are reset to ``None``.
    """
    fig = self._fig
    if fig is None:
        return

    # Imported lazily so matplotlib is only needed when a figure was opened.
    import matplotlib.pyplot as plt

    plt.close(fig)
    self._fig = None
    self._ax = None
render_mode="rgb_array", + seed=42, + ) + # Start spinning about the intermediate axis with a small perturbation + obs, _ = env.reset(seed=42, options={"axis": 1}) + print("Inertia:", env.get_inertia_info()) + print(f"I1={env.I_diag[0]:.5f} I2={env.I_diag[1]:.5f} I3={env.I_diag[2]:.5f} kg·m²") + + frames = [] + frame = env.render() + if frame is not None: + frames.append(frame) + + for step in range(N_STEPS): + obs, reward, _, _, info = env.step([0.0, 0.0, 0.0]) + frame = env.render() + if frame is not None: + frames.append(frame) + if step % 40 == 0: + R, omega = env.get_state() + print( + f"Step {step:4d} | reward={reward:.4f} | " + f"|ω|={np.linalg.norm(omega):.3f} | " + f"det(R)={np.linalg.det(R):.8f}" + ) + + env.close() + + os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True) + fig_vid, ax_vid = plt.subplots(figsize=(7, 7)) + ax_vid.axis("off") + im = ax_vid.imshow(frames[0]) + + def _update(i): + im.set_data(frames[i]) + return [im] + + anim = FuncAnimation(fig_vid, _update, frames=len(frames), + interval=1000 / 30, blit=True) + anim.save(SAVE_PATH, writer="ffmpeg", fps=30, dpi=100) + plt.close(fig_vid) + print(f"Saved to {SAVE_PATH}") + diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/README.md b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/README.md new file mode 100644 index 0000000..1f7b46e --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/README.md @@ -0,0 +1,530 @@ +# Port-Hamiltonian Neural ODE on SO(3) — Tennis Racket / Dzhanibekov Effect + +A physics-structured neural ODE that learns the Hamiltonian and dissipation of a free +rigid body from trajectory data, then uses that learned structure to stabilise the +system with an IDA-PBC controller. + +--- + +## 1. 
The physics + +### 1.1 Euler's equations for a free rigid body + +A rigid body rotating freely in space obeys Euler's equations in the body frame: + +``` +I ω̇ = −ω × (I ω) + τ +Ṙ = R · hat(ω) +``` + +where `I = diag(I₁, I₂, I₃)` is the inertia tensor, `ω ∈ ℝ³` is the angular +velocity in the body frame, `R ∈ SO(3)` is the orientation, and `τ ∈ ℝ³` is an +applied body-frame torque. + +### 1.2 The Dzhanibekov (intermediate-axis) theorem + +The three principal axes have a fundamental stability difference: + +| Axis | Inertia | Open-loop stability | Notes | +|------|---------|---------------------|-------| +| e₁ (long head) | I₁ = 0.0057 kg·m² | **stable** (oscillations) | smallest I | +| e₂ (short head) | I₂ = 0.0112 kg·m² | **UNSTABLE** (saddle) | intermediate I | +| e₃ (handle) | I₃ = 0.0132 kg·m² | **stable** (oscillations) | largest I | + +Rotation about e₂ is linearly unstable: the linearised free-body eigenvalue is +`λ = ±ω₂ √[(I₂−I₃)(I₁−I₂)/(I₁I₃)]`, which is real positive for intermediate I. +For cfg0 at ω₂ = 2π rad/s: `λ = ±2.40 s⁻¹` (doubling time ≈ 0.29 s). + +### 1.3 Port-Hamiltonian structure + +The system is cast as a port-Hamiltonian system on SO(3) × ℝ³: + +``` +H(R, ω) = ½ ωᵀ M⁻¹(R) ω + V(R) +``` + +- `M⁻¹(R)` — effective inverse-inertia (PSD matrix, learned by `M_net`) +- `V(R)` — potential energy (zero for free body in space, learned by `V_net`) +- `Dw(R)` — dissipation matrix (zero for frictionless body, learned by `Dw_net`) +- `g(R)` — input coupling matrix (= I₃ for direct body-frame torque, learned by `g_net`) + +The dynamics in port-Hamiltonian form: + +``` +I ω̇ = J(R,ω) ∇_ω H − Dw ∇_ω H + g u +Ṙ = R · hat(ω) +``` + +where `J` is the structure matrix encoding gyroscopic coupling. + +--- + +## 2. 
Model architecture + +### 2.1 State representation + +| Stage | Dimension | Layout | +|-------|-----------|--------| +| A, B | 15D | `[vec(R)₉, ω₃, u₃]` | +| C | 18D | `[vec(R)₉, I₁I₂I₃₃, ω₃, u₃]` — inertia embedded in state | + +The inertia is embedded in the state vector in Stage C so that one model generalises +across multiple racket geometries at inference time. + +### 2.2 Sub-networks + +All sub-networks take `q_ext = [vec(R), I₁, I₂, I₃]` (12D input in Stage C). + +| Network | Output | Constraint | Notes | +|---------|--------|-----------|-------| +| `M_net` | 3×3 matrix | PSD via `L Lᵀ + ε I` | inverse inertia `M⁻¹` | +| `V_net` | scalar | — | potential energy | +| `Dw_net` | 3×3 matrix | PSD | dissipation | +| `g_net` | 3×3 matrix | — | input coupling | + +### 2.3 `FixedInertiaFromState` (Stage C) + +Instead of training `M_net` to learn the inertia from SO(3) geometry, we pin it to the +exact physical value read from the embedded state: + +```python +class FixedInertiaFromState(nn.Module): + def forward(self, q_ext): # q_ext: (N, 12) + return torch.diag_embed(1.0 / q_ext[:, 9:12]) # (N, 3, 3) +``` + +This resolves the `M → αM` scale invariance (Hamiltonian degeneracy) that arises in +free-body training and makes multi-config generalisation exact by construction. + +### 2.4 Integration + +The neural ODE is integrated using `torchdiffeq` RK4 (`method='rk4'`). +The physical environment uses a Lie-group Heun integrator with 10 substeps per +`dt = 0.05 s` step. + +--- + +## 3. Training stages + +### Stage A — single-config, torque-free (`train.py`) + +Train on one racket geometry with zero applied torque. +`--fix_M` pins M_net to the exact inertia (avoids `M→αM` degeneracy). + +**Result** — windowed geodesic loss = 2.03×10⁻⁶ (threshold 0.01 rad²). +Key insight: without `--lambda_V_zero`, the optimizer finds a `(M_wrong, V_compensating)` +saddle — V absorbs what M misses. Force `V → 0` to recover physical M. 
+ +--- + +### Stage A end-to-end (`train_e2e.py`) + +Full end-to-end training (M, V, Dw, g jointly). +Uses `--lambda_V_zero 1.0` to break the (M, V) Hamiltonian degeneracy. + +--- + +### Stage B — friction (`train_stageB.py`) + +Add physical friction `τ_fric = −fric_coeff · ω` to the dataset. +`Dw_net` learns the dissipation: `Dw_loss` converges to 5.13×10⁻¹⁴ (10 orders of +magnitude below threshold). Friction makes the system contractive, so full-trajectory +variance collapses to near zero. + +--- + +### Stage C — multi-config generalisation (`train_stageC.py`, `network_stageC.py`) + +Train jointly on 4 racket geometries with different inertia values. +State extended to 18D: `[vec(R)₉, I₁I₂I₃₃, ω₃, u₃]`. + +Key design decisions: +- `FixedInertiaFromState` reads `I` from state cols 9:12, returns `diag(1/I)`. +- `strip_inertia(x)` removes cols 9:12 before passing 18D state to loss. +- `split = [9, 3, 3]` unchanged — loss sees 15D after stripping. +- Per-sample `M_tgt` computed from embedded inertia in `subnet_diagnostics_stageC.py`. + +**Result** — windowed geo = 2.03×10⁻⁶ across all 4 configs. +Held-out 5th geometry (different head/handle dimensions): same geo = 2.03×10⁻⁶. +Generalisation is exact because `FixedInertiaFromState` reads the exact I from state. + +Run: +```bash +venv/bin/python3 train_stageC.py --fix_M --total_steps 5000 +``` + +Checkpoint saved to `data/run_tr3d_stageC_fp32/`. + +--- + +### Stage D — IDA-PBC stabilisation (`controller_stageD.py`, `simulate_stageD.py`) + +Stabilise spinning about **e₃** (handle axis) starting from Dzhanibekov tumbling (e₂). + +Because the environment applies torque directly (`τ_total = u + τ_disturbance`), the +true input coupling is `g = I₃` exactly. 
The IDA-PBC law simplifies to: + +``` +u = K_p · (ω* − ω) ω* = (0, 0, 2π) rad/s +``` + +Linearised stability about ω* = ω*₃ e₃ (overdamped for K_p ≥ 0.071): + +| Mode | Eigenvalue | Time constant | +|------|-----------|---------------| +| δω₃ | −K_p/I₃ = −7.6 s⁻¹ | 0.13 s | +| δω₁,₂ (overdamped) | −10.2, −16.2 s⁻¹ | 0.06–0.10 s | + +**Result** (20 trials, no disturbance): + +| K_p | Conv time | Final ‖ω−ω*‖ | +|-----|-----------|--------------| +| 0.01 | 2.97 s | 0.002 rad/s | +| 0.05 | 0.65 s | ~0 | +| **0.10** | **0.30 s** | **~0** (recommended) | +| 0.20 | 0.10 s | ~0 | + +Run: +```bash +venv/bin/python3 simulate_stageD.py --Kp_scan "0.01,0.05,0.10,0.20" +``` + +--- + +### Stage E — e₁, e₂ hold, wind robustness (`simulate_stageE.py`) + +Three scenarios using the same `u = K_p(ω*−ω)` controller: + +**Scenario A** — redirect from e₂ tumbling → **e₁** (no wind): +100% convergence in 0.25 s with K_p = 0.10. e₁ is naturally stable; trivial. + +**Scenario B** — hold **e₂** (saddle, no wind), K_p scan: +The saddle instability requires active feedback to counteract: + +``` +K_p_min = ω*₂ · √(|I₂−I₃| · |I₁−I₂|) = 0.021 N·m·s/rad (theory) + ≤ 0.020 (empirical) +``` + +ZOH stability bound — with control update period T = dt = 0.05 s, the discrete +eigenvalue of the I₁ mode is `z = 1 − K_p T/I₁`. For stability need |z| < 1: + +``` +K_p_max = 2·I₁/T = 2 × 0.0057/0.05 = 0.229 N·m·s/rad +``` + +K_p = 0.50 violates this: z = −3.39, final error grows to 13.7 rad/s. Confirmed. + +Safe operating range for e₂ hold: **K_p ∈ (0.021, 0.229)**. + +**Scenario C** — hold **e₂** with stochastic wind (σ = 0.05 N·m): +Constant-per-episode disturbance `d ∼ N(0, σ² I₃)` sets a steady-state offset: + +``` +‖δω_ss‖ ≈ |d| / K_p (linear prediction) +``` + +Measured ratio (final error / predicted SS) at K_p = 0.20: **1.019** — clean. 
+The wind further narrows the viable K_p window: + +``` +K_p > |d| / ε_conv = 0.084 / 0.5 = 0.168 (wind-offset requirement) +K_p < K_p_max = 0.229 (ZOH stability) +→ viable range: (0.17, 0.23) +``` + +K_p = 0.20 achieves 95% convergence. PASS. + +**Scenario D** — hold e₁ with wind (σ = 0.05 N·m): +e₁ is stable (free-body λ = ±3.31i), so K_p_min = 0. With wind the binding lower +constraint is K_p > |d|/ε_conv ≈ 0.17. K_p = 0.20 achieves 95% convergence, SS +error ratio = 0.987 (nearly perfect linear prediction). + +**Scenario E** — hold e₃ with wind (σ = 0.05 N·m): +e₃ is stable (free-body λ = ±3.05i), same analysis. K_p = 0.20 achieves 95% +convergence, SS error ratio = 0.999 — the most accurate ratio across all axes. + +The ZOH bound K_p_max = 2·I₁/T = 0.229 is **the same for all three axes** because +it is set by the smallest inertia I₁, not by which axis you are holding. + +SS offset ratios across all axes at K_p = 0.20: + +| Scenario | Target | SS error / (|d|/K_p) | Note | +|----------|--------|----------------------|------| +| C | e₂ (unstable) | 1.019 | wind + saddle narrowing | +| D | e₁ (stable) | 0.987 | wind only | +| E | e₃ (stable) | 0.999 | wind only, cleanest | + +Run: +```bash +venv/bin/python3 simulate_stageE.py # original three scenarios (A, B, C) +venv/bin/python3 simulate_stageE_v2.py # all five scenarios (A–E) +``` + +--- + +### Stage F — Geometric attitude control (`controller_stageF.py`, `simulate_stageF.py`) + +Extends Stage D to stabilise the **full state (R*, ω*) on SO(3) × ℝ³**, not just ω*. +Stage D/E could drive ω → ω* but had no mechanism to control the orientation R. + +Control law: +``` +u = −K_R · e_R − K_p · e_ω + +e_R = vee( logm(R*ᵀ R) ) ∈ ℝ³ geodesic attitude error +e_ω = ω − ω* ∈ ℝ³ angular velocity error +``` + +`vee` extracts the axial vector of a skew-symmetric matrix; `logm` is the matrix +logarithm on SO(3) implemented via the Rodrigues formula. When `K_R = 0` this +reduces exactly to the Stage D proportional controller. 
+ +**IDA-PBC interpretation**: desired Hamiltonian +`H_d = H + ½K_R ‖e_R‖² + ½K_p ‖e_ω‖²`, J_d = J (unchanged), R_d = R + K_p I +(added damping). + +**Linearised closed-loop** (per axis i — independent 2nd-order oscillator): + +| Quantity | Formula | K_R=0.10, K_p=0.10, I=I₁ | +|----------|---------|--------------------------| +| Natural frequency | ωn = √(K_R/Iᵢ) | 4.18 rad/s | +| Damping ratio | ζ = K_p/(2√(K_R·Iᵢ)) | 2.09 (overdamped) | + +**ZOH bounds** (from discrete-time Schur stability): +``` +K_p < 2·I₁/T = 0.229 (velocity loop — same as Stage D/E) +K_R < K_p/T = 2.000 (attitude loop — 2nd-order ZOH constraint) +``` + +**Almost-global stability**: fails only at the antipodal set {‖e_R‖ = π}, +which has measure zero on SO(3). + +**Results** (20 trials, Dzhanibekov start, random R₀): + +| Scenario | Target (R*, ω*) | K_R | Conv time | Notes | +|----------|-----------------|-----|-----------|-------| +| A | I₃, 0 | 0.10 | 2.75 s | 100% ✓ | +| B | Ry(π/4), 0 | 0.10 | 2.82 s | 100% ✓, arbitrary orientation | +| C | I₃, (0,0,2π) | 0 | 0.30 s | 100% ✓, exact Stage D match | +| D (wind σ=0.05) | I₃, 0 | 1.50 | 1.80 s | 95% ✓, SS ‖e_R‖=0.055 rad | + +**Orientation–velocity duality under wind** (new finding in Stage F): + +With ω* = 0, constant wind d ≠ 0 is balanced in steady state by the **orientation** +restoring torque, not the velocity damping: + +``` +K_R · e_R = d → ‖e_R‖_ss = |d| / K_R (velocity → 0 exactly) +``` + +Compare with Stage E (ω* ≠ 0): `‖e_ω‖_ss = |d| / K_p` (orientation unconstrained). +The SS ratio ‖e_R‖_actual / (|d|/K_R) = **1.000** for all tested K_R ∈ [0.10, 1.50]. + +Viable K_R window with wind to satisfy ‖e_R‖ < θ_thr: +``` +K_R_min(wind) = |d| / θ_thr ≈ 0.082/0.10 = 0.82 (SS offset requirement) +K_R_max(ZOH) = K_p / T = 2.00 (discrete stability) +→ viable range: (0.82, 2.00) best: K_R = 1.50 +``` + +Run: +```bash +venv/bin/python3 simulate_stageF.py +``` + +--- + +## 4. 
Summary of stage results + +| Stage | Task | Key metric | Result | +|-------|------|-----------|--------| +| A | Single-config dynamics, fix_M | windowed geo | 2.03×10⁻⁶ ✓ | +| A e2e | End-to-end (M,V,Dw,g) | windowed geo | converges with `--lambda_V_zero` | +| B | Friction identification | Dw_loss | 5.13×10⁻¹⁴ ✓ | +| C | 4-config + 5th held-out | windowed geo | 2.03×10⁻⁶ on all ✓ | +| D | e₃ stabilisation, no wind | conv time | 0.30 s at K_p=0.10 ✓ | +| E-A | e₁ redirect (no wind) | conv time | 0.25 s ✓ | +| E-B | e₂ hold, no wind | K_p_min | 0.020 ≈ theory 0.021 ✓ | +| E-C | e₂ hold, σ=0.05 wind | 95% conv | K_p=0.20, ratio=1.019 ✓ | +| E-D | e₁ hold, σ=0.05 wind | 95% conv | K_p=0.20, ratio=0.987 ✓ | +| E-E | e₃ hold, σ=0.05 wind | 95% conv | K_p=0.20, ratio=0.999 ✓ | +| F-A | rest at R*=I₃, no wind | 95% conv | K_R=0.10, 2.75s ✓ | +| F-B | rest at R*=Ry(π/4), no wind | 95% conv | K_R=0.10, 2.82s ✓ | +| F-C | sanity check K_R=0 | matches Stage D | 0.30s ✓ | +| F-D | rest at R*=I₃, σ=0.05 wind | 95% conv | K_R=1.50, ratio=1.000 ✓ | + +--- + +## 5. Key non-obvious findings + +**Hamiltonian degeneracy (Stage A)**: For a free body `V = 0` is exact physics, +but the optimizer finds a `(M_wrong, V_compensating)` saddle: `M_loss` diverges while +`geo_loss` converges. Fix: add `--lambda_V_zero 1.0` to regularise `V → 0`. + +**M → αM scale invariance**: Under `M → αM` the free-body Hamiltonian `H = ½ωᵀM⁻¹ω` is only +rescaled by a constant factor, so the torque-free equations of motion are unchanged: Euler's +equations depend only on ratios of the inertia components, and torque-free trajectory data +alone cannot fix the absolute scale of M. `FixedInertiaFromState` resolves this by anchoring M to absolute physical +values read from the state. + +**ZOH discrete instability** (Stage E): The controller computes `u` once per +`dt = 0.05 s` step (zero-order hold). The discrete-time eigenvalue for the lightest axis +is `z = 1 − K_p T/I₁`. For stability: `K_p < 2I₁/T = 0.229`. K_p = 0.50 gives +`z = −3.39` — the error grows with alternating sign each step, reaching 14 rad/s.
+ +**Narrow viable window for e₂ + wind**: Two competing constraints, +`K_p > |d|/ε_conv` (wind rejection) and `K_p < 2I₁/T` (ZOH), leave a window of +width ≈ 0.06 N·m·s/rad for the default parameters. Reducing dt or increasing I₁ +(heavier/larger racket head) widens it. + +**Universal ZOH bound across all axes**: K_p_max = 2·I₁/T = 0.229 is the same for +e₁, e₂, and e₃ hold. It is set by I₁ (the smallest inertia — the mode that +reaches instability first), regardless of which axis is the target. Confirmed: +K_p = 0.50 diverges on all three axes with late_max ≈ 14 rad/s. + +**SS offset ratio is universal**: The steady-state formula ‖δω_ss‖ ≈ |d|/K_p holds +across all three axes (Stage E). + +**Orientation–velocity duality under wind (Stage F)**: With ω* ≠ 0 (Stage E), constant +wind d creates a persistent *velocity* offset ‖e_ω‖_ss = |d|/K_p. With ω* = 0 +(Stage F), the body comes to rest but settles at a tilted orientation: the wind torque +is balanced by the orientation restoring torque K_R · e_R, giving ‖e_R‖_ss = |d|/K_R. +The ratio is exactly 1.000 across all K_R values (0.10–1.50 N·m/rad), confirming the +linear prediction is tight in orientation space as well as velocity space. + +**ZOH bound for the attitude loop (Stage F)**: Adding orientation feedback creates a +2nd-order discrete system with an additional ZOH constraint: K_R < K_p/T. For +K_p=0.10, dt=0.05: K_R_max = 2.00 N·m/rad. Viable wind-rejection window: +K_R ∈ (|d|/θ_thr, K_p/T) ≈ (0.82, 2.00). + +--- + +## 6. 
File index + +| File | Description | +|------|-------------| +| `network_stageC.py` | `DissipativeSO3HamNODE` with 18D state; `FixedInertiaFromState` | +| `train.py` | Stage A training (single-config, torque-free) | +| `train_e2e.py` | Stage A end-to-end (joint M, V, Dw, g) | +| `train_stageB.py` | Stage B (friction) | +| `train_stageC.py` | Stage C (multi-config, 18D state) | +| `subnet_diagnostics_stageC.py` | Per-sample subnet MSE with embedded inertia | +| `eval_stageC_5th.py` | Held-out 5th geometry generalisation test | +| `controller_stageD.py` | `PDBodyFrameController`: u = K_p(ω*−ω) | +| `simulate_stageD.py` | Stage D closed-loop simulation + K_p scan | +| `simulate_stageE.py` | Stage E original — e₁ redirect, e₂ hold no-wind, e₂ hold wind | +| `simulate_stageE_v2.py` | Stage E v2 — adds e₁+wind (Scenario D) and e₃+wind (Scenario E) | +| `controller_stageF.py` | `GeometricAttitudeController`: `hat`, `vee`, `logm_SO3`, `expm_SO3` | +| `simulate_stageF.py` | Stage F — full-state (R*, ω*) stabilisation; 4 scenarios | + +Related files outside this directory: + +| File | Description | +|------|-------------| +| `envs/tennis_racket_3d.py` | Gymnasium env: Lie-group Heun integrator, direct-torque actuator | +| `envs/tennis_racket_3d_friction.py` | Stage B env with friction | +| `datasets/tennis_racket_3d_datagen.py` | Dataset generator (multi-config) | +| `src/models/3D_SO3_Windy_Pendulum/ph_nn_ode_v2/network.py` | Shared `DissipativeSO3HamNODE` base | +| `src/utils/loss_utils.py` | `rotmat_L2_geodesic_loss_safe`, `traj_rotmat_L2_geodesic_loss_safe` | + +--- + +## 7. 
How to run + +All commands assume the project virtual environment: + +```bash +PYTHON=/Users/katesur/Projects/LieSPHGP/venv/bin/python3 +``` + +### Generate dataset + +```bash +$PYTHON datasets/tennis_racket_3d_datagen.py \ + --ncfg 4 --timesteps 100 --trials 25 \ + --perturb_std 0.05 --dist_std 0.0 --obs_noise 0.0 +``` + +Output: `data/tennis_data/tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl` + +### Stage C training + +```bash +cd src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2 +$PYTHON -u train_stageC.py --fix_M --total_steps 5000 \ + --num_points 5 --n_samples 25 | tee train_stageC.log +``` + +Important flags: + +| Flag | Default | Effect | +|------|---------|--------| +| `--fix_M` | off | Pin M_net to FixedInertiaFromState | +| `--lambda_V_zero` | 0 | Weight for V→0 regularisation (set 1.0 for e2e) | +| `--total_steps` | 5000 | Training steps | +| `--num_points` | 5 | Window size for windowed loss | + +### Stage D / E simulation + +```bash +# K_p scan for e₃ stabilisation +$PYTHON -u simulate_stageD.py --Kp_scan "0.01,0.05,0.10,0.20" + +# Full Stage E (e₁, e₂ hold, wind) +$PYTHON -u simulate_stageE.py + +# Custom convergence threshold or longer run +$PYTHON -u simulate_stageE.py --n_steps 400 --conv_thr 0.25 --wind_std 0.10 +``` + +### Stage E simulation (all five scenarios) + +```bash +$PYTHON -u simulate_stageE_v2.py + +# Override wind strength or run length +$PYTHON -u simulate_stageE_v2.py --n_steps 400 --wind_std 0.10 +``` + +### Stage F simulation (geometric attitude control) + +```bash +$PYTHON -u simulate_stageF.py + +# Longer run or tighter attitude threshold +$PYTHON -u simulate_stageF.py --n_steps 400 --theta_thr 0.05 +``` + +### Evaluate held-out geometry + +```bash +$PYTHON -u eval_stageC_5th.py # auto-finds latest checkpoint +``` + +--- + +## 8. 
class PDBodyFrameController:
    """Proportional controller for stabilising ω → ω*.

    The commanded torque is ``Kp * (omega_star - omega)``, optionally mapped
    through the pseudo-inverse of a learned input-coupling matrix and finally
    clipped element-wise to ``[-clip, clip]``.

    Parameters
    ----------
    omega_star : array-like, shape (3,)
        Target angular velocity in body frame [rad/s].
    Kp : float
        Proportional gain [N·m·s/rad].
    clip : float, optional
        Hard torque clip [N·m], applied symmetrically to every component.
    use_model_g : bool
        If True (and ``model`` is given), use the model's g_net as input
        coupling:
            u = g⁺(R,I) · K_p(ω* - ω)
        If False (default), assume g = I₃:
            u = K_p(ω* - ω)
        When True but ``model`` is None the controller falls back to the
        g = I₃ law.
    model : optional
        Object exposing a ``g_net`` callable that maps a (1, 12) float32
        tensor ``[vec(R), I1, I2, I3]`` to a (1, 3, 3) tensor.  Used only when
        ``use_model_g=True``.
    I1, I2, I3 : float, optional
        Principal inertia values appended to vec(R) for ``g_net``.  Required
        whenever the learned coupling is actually evaluated.
    """

    def __init__(
        self,
        omega_star,
        Kp: float = 0.10,
        clip: float = 2.0,
        use_model_g: bool = False,
        model=None,
        I1: float = None,
        I2: float = None,
        I3: float = None,
    ):
        self.omega_star = np.asarray(omega_star, dtype=np.float64).reshape(3)
        self.Kp = float(Kp)
        self.clip = float(clip)
        self.use_model_g = use_model_g
        self.model = model
        self.I1 = I1
        self.I2 = I2
        self.I3 = I3

    # ── Internal helpers ───────────────────────────────────────────────────

    def _tau_desired(self, omega: np.ndarray) -> np.ndarray:
        """Return the desired body-frame torque ``Kp * (omega_star - omega)``."""
        return self.Kp * (self.omega_star - omega)

    def _invert_g(self, R_flat: np.ndarray, tau_des: np.ndarray) -> np.ndarray:
        """Compute u = g⁺ · τ_des using the learned g_net.

        Uses the least-squares pseudo-inverse: min_u ‖g·u − τ_des‖².

        Raises
        ------
        ValueError
            If any of I1, I2, I3 was not supplied; g_net cannot be evaluated
            without the inertia part of its input.
        """
        inertia = (self.I1, self.I2, self.I3)
        if any(value is None for value in inertia):
            raise ValueError(
                "I1, I2 and I3 must all be provided when use_model_g=True")

        g_net = self.model.g_net
        # Build the input on the device that holds g_net's parameters; a
        # plain callable without parameters is evaluated on the default device.
        try:
            device = next(g_net.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            device = None

        with torch.no_grad():
            q_ext = torch.tensor(
                np.concatenate([R_flat, inertia]),
                dtype=torch.float32,
                device=device,
            ).unsqueeze(0)                                   # (1, 12)
            g_mat = g_net(q_ext)                             # (1, 3, 3)
            # .cpu() makes the NumPy conversion valid for any device.
            g_np = g_mat.squeeze(0).detach().cpu().numpy()   # (3, 3)

        u, _, _, _ = np.linalg.lstsq(g_np, tau_des, rcond=None)
        return u

    # ── Public API ─────────────────────────────────────────────────────────

    def __call__(
        self,
        R_flat: np.ndarray,
        omega: np.ndarray,
    ) -> np.ndarray:
        """Compute the control torque for the current state.

        Parameters
        ----------
        R_flat : (9,) float64 — flattened rotation matrix (row-major)
        omega  : (3,) float64 — body-frame angular velocity [rad/s]

        Returns
        -------
        u : (3,) float64 — body-frame torque to apply [N·m], clipped to ±clip
        """
        tau_des = self._tau_desired(np.asarray(omega, dtype=np.float64))

        if self.use_model_g and self.model is not None:
            u = self._invert_g(np.asarray(R_flat, dtype=np.float64), tau_des)
        else:
            # R_flat is unused here: with g = I₃ the torque is τ_des itself.
            u = tau_des

        # Clip to actuator limits
        return np.clip(u, -self.clip, self.clip)

    # ── Diagnostics ────────────────────────────────────────────────────────

    def omega_error(self, omega: np.ndarray) -> float:
        """‖ω - ω*‖₂ in rad/s."""
        return float(np.linalg.norm(np.asarray(omega) - self.omega_star))
# ─────────────────────────────────────────────────────────────────────────────
# SO(3) geometry helpers
# ─────────────────────────────────────────────────────────────────────────────

def hat(v):
    """Skew-symmetric matrix (hat map) for v ∈ ℝ³.

    hat(v) ω = v × ω  for any ω ∈ ℝ³.
    """
    v = np.asarray(v, dtype=np.float64).ravel()
    return np.array([[ 0.0,  -v[2],  v[1]],
                     [ v[2],  0.0,  -v[0]],
                     [-v[1],  v[0],  0.0]])


def vee(Omega):
    """Axial vector of a skew-symmetric matrix (inverse of hat).

    vee(hat(v)) = v  for any v ∈ ℝ³.
    """
    return np.array([Omega[2, 1], Omega[0, 2], Omega[1, 0]])


def logm_SO3(R):
    """Matrix logarithm of R ∈ SO(3).

    Returns a skew-symmetric Ω ∈ so(3) with ‖vee(Ω)‖ ≤ π such that
    expm(Ω) = R.  Uses the Rodrigues formula:

        Ω = θ / (2 sin θ) · (R − Rᵀ),   θ = arccos((tr R − 1)/2)

    Special cases:
        θ ≈ 0  (near identity): first-order approximation (R−Rᵀ)/2.
        θ ≈ π  (antipodal):     the axis direction is reconstructed from the
                                symmetric part of R and its sign is taken from
                                the antisymmetric part R − Rᵀ, so the result
                                stays consistent with the generic branch for
                                θ slightly below π.  At exactly θ = π both
                                signs describe the same rotation.

    Parameters
    ----------
    R : array-like reshapeable to (3, 3) — rotation matrix.

    Returns
    -------
    (3, 3) float64 skew-symmetric matrix.
    """
    R = np.asarray(R, dtype=np.float64).reshape(3, 3)
    cos_theta = float(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    theta = float(np.arccos(cos_theta))
    skew = R - R.T                      # = 2 sin θ · hat(n)

    if theta < 1e-7:
        return skew / 2.0

    if abs(theta - np.pi) < 1e-4:
        # (R + Rᵀ)/2 = cos θ · I + (1 − cos θ) n nᵀ, hence
        # ((R + Rᵀ)/2 + I)/2 → n nᵀ as θ → π.
        sym = 0.5 * (0.5 * (R + R.T) + np.eye(3))
        i = int(np.argmax(np.diag(sym)))
        n = sym[:, i] / np.sqrt(max(sym[i, i], 1e-15))
        n = n / max(float(np.linalg.norm(n)), 1e-15)
        # n nᵀ only fixes n up to sign; R − Rᵀ carries the sign of sin θ · n.
        if float(np.dot(vee(skew), n)) < 0.0:
            n = -n
        return theta * hat(n)

    return (theta / (2.0 * np.sin(theta))) * skew


def expm_SO3(Omega):
    """Matrix exponential of Ω ∈ so(3) (Rodrigues formula).

    Useful for constructing target rotations, e.g.:
        R_target = expm_SO3(hat([0, np.pi/4, 0]))   # 45° about e₂
    """
    v = vee(Omega)
    theta = float(np.linalg.norm(v))
    if theta < 1e-7:
        # First-order expansion; avoids dividing by a vanishing angle.
        return np.eye(3) + Omega
    n = v / theta
    return (np.cos(theta) * np.eye(3)
            + np.sin(theta) * hat(n)
            + (1.0 - np.cos(theta)) * np.outer(n, n))


# ─────────────────────────────────────────────────────────────────────────────
# Controller
# ─────────────────────────────────────────────────────────────────────────────

class GeometricAttitudeController:
    """Full-state stabiliser for (R*, ω*) ∈ SO(3) × ℝ³.

    Implements u = −K_R · e_R − K_p · e_ω with
    e_R = vee(logm(R*ᵀ R)) and e_ω = ω − ω*, clipped element-wise to ±clip.

    Parameters
    ----------
    R_star     : (3,3) target orientation in SO(3).
    omega_star : (3,)  target angular velocity in the body frame [rad/s].
    K_R        : orientation gain [N·m/rad].  Zero → pure proportional law
                 on the angular-velocity error.
    K_p        : angular velocity gain [N·m·s/rad].
    clip       : torque saturation [N·m].
    """

    def __init__(self, R_star, omega_star, K_R=0.10, K_p=0.10, clip=2.0):
        self.R_star = np.asarray(R_star, dtype=np.float64).reshape(3, 3)
        self.omega_star = np.asarray(omega_star, dtype=np.float64).reshape(3)
        self.K_R = float(K_R)
        self.K_p = float(K_p)
        self.clip = float(clip)

    def attitude_error(self, R):
        """e_R = vee(logm(R*ᵀ R)) ∈ ℝ³  (zero when R = R*)."""
        return vee(logm_SO3(self.R_star.T @ np.asarray(R).reshape(3, 3)))

    def __call__(self, R_flat, omega):
        """Compute control torque u = −K_R e_R − K_p e_ω, clipped to ±clip."""
        R = np.asarray(R_flat, dtype=np.float64).reshape(3, 3)
        omega = np.asarray(omega, dtype=np.float64).ravel()
        e_R = self.attitude_error(R)
        e_w = omega - self.omega_star
        u = -self.K_R * e_R - self.K_p * e_w
        return np.clip(u, -self.clip, self.clip)

    def errors(self, R_flat, omega):
        """Returns (‖e_R‖ [rad], ‖e_ω‖ [rad/s]).

        Both are zero at the target (R*, ω*).
        ‖e_R‖ lies in [0, π]; ‖e_R‖ = π is the worst-case antipodal point.
        """
        R = np.asarray(R_flat, dtype=np.float64).reshape(3, 3)
        omega = np.asarray(omega, dtype=np.float64).ravel()
        e_R = self.attitude_error(R)
        e_w = omega - self.omega_star
        return float(np.linalg.norm(e_R)), float(np.linalg.norm(e_w))
# ── Training config inertia for reference ─────────────────────────────────────
# Printed by main() next to the held-out geometry; not used in any computation.
TRAIN_CONFIGS = [
    dict(I1=0.005719, I2=0.011212, I3=0.013221,
         head_a=0.195, head_b=0.135, handle_length=0.25, total_mass=0.29, head_mass_ratio=0.70),
    dict(I1=0.006757, I2=0.008643, I3=0.011019,
         head_a=0.195, head_b=0.135, handle_length=0.25, total_mass=0.30, head_mass_ratio=0.80),
    dict(I1=0.004738, I2=0.014832, I3=0.016495,
         head_a=0.195, head_b=0.135, handle_length=0.28, total_mass=0.28, head_mass_ratio=0.60),
    dict(I1=0.006336, I2=0.012067, I3=0.014693,
         head_a=0.210, head_b=0.135, handle_length=0.25, total_mass=0.29, head_mass_ratio=0.70),
]


def get_args(argv=None):
    """Parse the command-line options of the held-out evaluation script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse.  ``None`` (default) makes argparse read
        ``sys.argv[1:]``, which is the behaviour of a plain ``get_args()``.

    Returns
    -------
    argparse.Namespace with the geometry, sampling and solver options.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--ckpt_path', type=str, default=None,
                   help='path to .tar checkpoint; defaults to latest Stage C fixM run')
    p.add_argument('--head_a', type=float, default=0.225,
                   help='head semi-axis along e₁ (long), m')
    p.add_argument('--head_b', type=float, default=0.115,
                   help='head semi-axis along e₂ (short), m')
    p.add_argument('--handle_length', type=float, default=0.27)
    p.add_argument('--total_mass', type=float, default=0.31)
    p.add_argument('--head_mass_ratio', type=float, default=0.75)
    p.add_argument('--n_samples', type=int, default=25)
    p.add_argument('--timesteps', type=int, default=100)
    p.add_argument('--num_points', type=int, default=5,
                   help='window size for windowed eval (must match training)')
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--solver', type=str, default='rk4')
    return p.parse_args(argv)


def find_latest_ckpt():
    """Return the final checkpoint path of the last ``*fixM*`` Stage C run.

    Run directories are looked up under ``data/run_tr3d_stageC_fp32`` next to
    this file and ordered by name; the last one in sorted order is used.

    Raises
    ------
    FileNotFoundError
        If no matching run directory exists, or the selected run does not
        contain ``tr3d-so3ham-rk4-5p.tar``.
    """
    run_base = os.path.join(THIS_FILE_DIR, 'data', 'run_tr3d_stageC_fp32')
    subdirs = sorted(glob.glob(os.path.join(run_base, '*fixM*')))
    if not subdirs:
        raise FileNotFoundError(f"No fixM Stage C runs under {run_base}")
    latest = subdirs[-1]
    final = os.path.join(latest, 'tr3d-so3ham-rk4-5p.tar')
    if not os.path.exists(final):
        raise FileNotFoundError(f"Final checkpoint not found: {final}")
    return final


def main(argv=None):
    """Evaluate a Stage C checkpoint on a racket geometry given by the CLI.

    Generates trajectories for the requested geometry, loads the checkpoint,
    reports a windowed and a full-trajectory geodesic loss plus sub-network
    MSEs, and prints PASS when the windowed geodesic loss is below 0.01.

    Parameters
    ----------
    argv : list of str, optional
        Forwarded to ``get_args``; ``None`` reads ``sys.argv[1:]``.

    Returns
    -------
    dict with the inertia values and the windowed / full-trajectory metrics.
    """
    args = get_args(argv)
    float_type = torch.float32
    torch.set_default_dtype(torch.float32)
    device = torch.device('cpu')

    # ── 5th config geometry ────────────────────────────────────────────────
    racket_kwargs = dict(
        head_a=args.head_a,
        head_b=args.head_b,
        handle_length=args.handle_length,
        total_mass=args.total_mass,
        head_mass_ratio=args.head_mass_ratio,
    )

    # Instantiate env just to read inertia (no data generated here)
    from envs.tennis_racket_3d import tennis_racket_3d as Env
    env5 = Env(**racket_kwargs)
    iinfo = env5.get_inertia_info()
    env5.close()
    I1, I2, I3 = iinfo['I1'], iinfo['I2'], iinfo['I3']

    print("=" * 60)
    print("Stage C — 5th config held-out generalization test")
    print("=" * 60)
    print(f"\n5th config geometry:")
    print(f" head_a={args.head_a:.3f} head_b={args.head_b:.3f} "
          f"handle={args.handle_length:.3f} mass={args.total_mass:.3f} "
          f"hmr={args.head_mass_ratio:.2f}")
    print(f" I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²")
    print(f" I2/I1={I2/I1:.2f} asym=(I3-I1)/I2={(I3-I1)/I2:.2f}")

    print(f"\nTraining configs (for comparison):")
    for ci, c in enumerate(TRAIN_CONFIGS):
        print(f" cfg{ci}: I1={c['I1']:.6f} I2={c['I2']:.6f} I3={c['I3']:.6f} "
              f"head_a={c['head_a']:.3f} head_b={c['head_b']:.3f} "
              f"asym={(c['I3']-c['I1'])/c['I2']:.2f}")

    # ── Generate trajectories ──────────────────────────────────────────────
    # The seed is offset by 99999 relative to --seed.
    print(f"\nGenerating {args.n_samples} trajectories "
          f"(T={args.timesteps}, seed={args.seed + 99999})...")
    trajs, tspan = sample_tennis_racket_3d(
        seed=args.seed + 99999,
        timesteps=args.timesteps,
        trials=args.n_samples,
        perturb_std=0.05,
        racket_kwargs=racket_kwargs,
    )
    # trajs: (T, N, 15)
    print(f" Generated: {trajs.shape} t ∈ [{tspan[0]:.3f}, {tspan[-1]:.3f}]s")

    # Augment with embedded inertia: (T, N, 15) → (T, N, 18)
    trajs_aug = augment_with_inertia(trajs, I1, I2, I3)

    # ── Load model ─────────────────────────────────────────────────────────
    ckpt_path = args.ckpt_path or find_latest_ckpt()
    print(f"\nLoading checkpoint:\n {ckpt_path}")

    model = DissipativeSO3HamNODE(
        device=device, u_dim=3, init_gain=0.5, inertia_dim=3).to(device)
    # Replace M_net before loading so the state dict is matched against the
    # fixed-inertia module; strict=False reports rather than raises on
    # key mismatches.
    model.M_net = FixedInertiaFromState().to(device).to(float_type)

    state_dict = torch.load(ckpt_path, map_location=device)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if unexpected:
        print(f" WARNING: unexpected keys in checkpoint: {unexpected}")
    if missing:
        print(f" INFO: keys not in checkpoint (expected for FixedInertiaFromState): {missing}")
    model.eval()

    # ── Windowed eval (same 5-point window as training) ────────────────────
    x_arr, t_eval = arrange_data(trajs_aug[np.newaxis], tspan,
                                 num_points=args.num_points)
    x_cat = np.concatenate(x_arr, axis=1)   # (num_points, B_windows, 18)
    x_cat_t = torch.tensor(x_cat, dtype=float_type)
    t_ev_t = torch.tensor(t_eval, dtype=float_type)

    with torch.no_grad():
        x_hat = odeint(model, x_cat_t[0], t_ev_t, method=args.solver)
        # The first time point is the initial condition, so it is excluded.
        tgt = strip_inertia(x_cat_t[1:])
        tgt_hat = strip_inertia(x_hat[1:])
        w_loss, w_l2, w_geo = rotmat_L2_geodesic_loss(tgt, tgt_hat, split=[9, 3, 3])
    subnet = subnet_physics_mse_stageC(model, x_hat)

    print(f"\n--- Windowed eval ({args.num_points}-point window, "
          f"N={x_cat.shape[1]} windows) ---")
    print(f" geo={w_geo:.4e} L2={w_l2:.4e} total={w_loss:.4e}")
    print(f" subnet MSE M={subnet['M_loss']:.3e} V={subnet['V_loss']:.3e} "
          f"Dw={subnet['Dw_loss']:.3e} g={subnet['g_loss']:.3e}")

    # ── Full-trajectory eval ───────────────────────────────────────────────
    x_full_t = torch.tensor(trajs_aug, dtype=float_type)   # (T, N, 18)
    t_full_t = torch.tensor(tspan, dtype=float_type)

    with torch.no_grad():
        x_full_hat = odeint(model, x_full_t[0], t_full_t, method=args.solver)
        tl, ll, gl = traj_rotmat_L2_geodesic_loss(
            strip_inertia(x_full_t), strip_inertia(x_full_hat), split=[9, 3, 3])

    gl_sum = gl.sum(dim=0)   # (N,) — geo summed over all T timesteps
    print(f"\n--- Full-trajectory eval (T={args.timesteps}, N={args.n_samples}) ---")
    print(f" geo = {gl_sum.mean():.4e} ± {gl_sum.std():.4e}")
    print(f" min = {gl_sum.min():.4e} max = {gl_sum.max():.4e}")

    # ── Verdict ────────────────────────────────────────────────────────────
    passed = w_geo.item() < 0.01
    print(f"\n{'PASS ✓' if passed else 'FAIL ✗'} windowed geo = {w_geo:.4e} "
          f"(threshold 0.01 rad²)")

    return {
        'I1': I1, 'I2': I2, 'I3': I3,
        'windowed_geo': w_geo.item(),
        'windowed_M': subnet['M_loss'],
        'windowed_V': subnet['V_loss'],
        'windowed_Dw': subnet['Dw_loss'],
        'full_traj_geo_mean': gl_sum.mean().item(),
        'full_traj_geo_std': gl_sum.std().item(),
    }


if __name__ == '__main__':
    main()
+ - JVP tangent is padded: (dR, 0₃) since d(inertia)/dt = 0. + - Return is 18D: cat(dq₉, zeros₃, ddq₃, zeros₃). + - FixedInertiaFromState reads I₁,I₂,I₃ from q_ext[:, 9:12] directly. +""" +import torch +import os, sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../utils'))) +from ode_nn_models import MLP, PSD, MatrixNet + + +class FixedInertiaFromState(torch.nn.Module): + """Drop-in M_net for Stage C: reads I₁,I₂,I₃ from q_ext[:, 9:12]. + + The state is (vec(R), I₁,I₂,I₃, ω, u), so each sample carries its own + inertia. This module returns diag(1/I₁, 1/I₂, 1/I₃) per sample without + any learnable parameters — use it to pin M_net to ground truth and check + that V/Dw/g learn correctly before releasing M_net. + """ + def forward(self, q_ext): + I_vals = q_ext[:, 9:12] # (N, 3) + return torch.diag_embed(1.0 / I_vals) # (N, 3, 3) + + +class DissipativeSO3HamNODE(torch.nn.Module): + def __init__(self, M_net=None, Dw_net=None, V_net=None, g_net=None, + device=None, u_dim=3, init_gain=0.01, friction=True, + inertia_dim=3): + super().__init__() + self.inertia_dim = inertia_dim # extra dims appended after vec(R) + self.rotmatdim = 9 + inertia_dim # input dim to all sub-nets + self.angveldim = 3 + self.friction = friction + + # epsilon=1.0 (fp32 stability fix) + self.M_net = M_net or PSD(self.rotmatdim, 20, self.angveldim, + init_gain=init_gain, epsilon=1.0).to(device) + + if friction: + self.Dw_net = Dw_net or PSD(self.rotmatdim, 20, self.angveldim, + init_gain=init_gain, epsilon=0.0).to(device) + self.V_net = V_net or MLP(self.rotmatdim, 20, 1, init_gain=init_gain).to(device) + + self.u_dim = u_dim + if g_net is None: + if u_dim == 1: + self.g_net = MLP(self.rotmatdim, 20, self.angveldim).to(device) + else: + self.g_net = MatrixNet(self.rotmatdim, 20, self.angveldim * self.u_dim, + shape=(self.angveldim, self.u_dim), + init_gain=init_gain).to(device) + else: + self.g_net = g_net + + self.device = device + self.nfe = 0 + + def 
forward(self, t, x): + with torch.enable_grad(): + self.nfe += 1 + bs = x.shape[0] + zero_vec = torch.zeros(bs, self.u_dim, dtype=x.dtype, device=x.device) + zero_inertia = torch.zeros(bs, self.inertia_dim, dtype=x.dtype, device=x.device) + + if not x.requires_grad: + x = x.detach().requires_grad_(True) + + # x = (q_ext[rotmatdim], q_dot[angveldim], u[u_dim]) + # = (vec(R)[9], I₁I₂I₃[3], ω[3], u[3]) for inertia_dim=3 + q_ext, q_dot, u = torch.split( + x, [self.rotmatdim, self.angveldim, self.u_dim], dim=1) + + # Rotation sub-block of q_ext — used for all geometry (cross products). + R_flat = q_ext[:, :9] # (B, 9) + + # ── Two-call cat-split structure (same as base network) ── + M_q = self.M_net(q_ext) + q_dot_aug = torch.unsqueeze(q_dot, dim=2) + p = torch.squeeze(torch.linalg.solve(M_q, q_dot_aug), dim=2) + + q_p = torch.cat((q_ext, p), dim=1) + q_ext_split, p = torch.split(q_p, [self.rotmatdim, self.angveldim], dim=1) + R_flat_split = q_ext_split[:, :9] # rotation part after cat-split + + M_q_inv = self.M_net(q_ext_split) + V_q = self.V_net(q_ext_split) + g_q = self.g_net(q_ext_split) + Dw_q = self.Dw_net(q_ext_split) + + p_aug = torch.unsqueeze(p, dim=2) + H = (torch.squeeze(torch.matmul(torch.transpose(p_aug, 1, 2), + torch.matmul(M_q_inv, p_aug))) / 2.0 + + torch.squeeze(V_q)) + + dH = torch.autograd.grad(H.sum(), q_p, create_graph=True)[0] + # dH has shape (B, rotmatdim + angveldim); first rotmatdim entries are + # dH/d(q_ext) = [dH/d(vec R), dH/d(I₁), dH/d(I₂), dH/d(I₃)]. + # Only the rotation sub-block enters the Lie bracket. + dHdq_ext, dHdp = torch.split(dH, [self.rotmatdim, self.angveldim], dim=1) + dHdR = dHdq_ext[:, :9] # (B, 9) — gradient w.r.t. 
vec(R) only + + if self.u_dim == 1: + F = g_q * u + else: + F = torch.squeeze(torch.matmul(g_q, torch.unsqueeze(u, dim=2))) + + # ── #2: Batched cross products ── + q_3x3 = R_flat_split.view(-1, 3, 3) # (B, 3, 3) + dHdp_b = dHdp.unsqueeze(1).expand(-1, 3, -1) # (B, 3, 3) + dq = torch.linalg.cross(q_3x3, dHdp_b, dim=2).reshape(-1, 9) + + dHdR_3x3 = dHdR.view(-1, 3, 3) # (B, 3, 3) + grav = torch.linalg.cross(q_3x3, dHdR_3x3, dim=2).sum(dim=1) # (B, 3) + + if self.friction: + dp = (torch.linalg.cross(p, dHdp, dim=1) + + grav + - torch.squeeze(torch.matmul(Dw_q, torch.unsqueeze(dHdp, dim=2))) + + F) + else: + dp = torch.linalg.cross(p, dHdp, dim=1) + grav + F + + # ── #1: dM_inv/dt via JVP ── + # Inertia is constant: d(I₁,I₂,I₃)/dt = 0. Pad dq to rotmatdim with zeros. + dq_ext_dt = torch.cat( + [dq, torch.zeros(bs, self.inertia_dim, dtype=x.dtype, device=x.device)], + dim=1) + _, dM_inv_dt = torch.func.jvp(self.M_net, (q_ext_split,), (dq_ext_dt,)) + + ddq = (torch.squeeze(torch.matmul(M_q_inv, torch.unsqueeze(dp, dim=2)), dim=2) + + torch.squeeze(torch.matmul(dM_inv_dt, torch.unsqueeze(p, dim=2)), dim=2)) + + # Return state derivative: (dR[9], d_inertia=0[3], dω[3], du=0[3]) + return torch.cat((dq, zero_inertia, ddq, zero_vec), dim=1) diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageD.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageD.py new file mode 100644 index 0000000..7cbca55 --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageD.py @@ -0,0 +1,281 @@ +"""Stage D simulation: IDA-PBC stabilisation of the tennis racket. + +Task: starting from Dzhanibekov tumbling (spin about unstable e₂), drive +the angular velocity to a stable target ω* = (0, 0, ω₀) (handle spin, e₃). + +The controller is a simple proportional law u = K_p·(ω* - ω), which is +exact for a direct-torque actuator (g = I₃). No model is needed. 
+ +Linearised analysis (about ω* for cfg0): + δω₃ mode: τ₃ = I₃/K_p ≈ 0.132 s + δω₁,δω₂ modes: Re(λ) = -K_p(1/I₁+1/I₂)/2 ≈ -13.2 s⁻¹ (τ ≈ 0.076 s) + Overdamped for K_p > 0.071; K_p = 0.10 is in the overdamped regime. + +Expected convergence: ‖ω-ω*‖ decays to <0.01 rad/s within ~0.5 s (10 steps). + +Usage: + /Users/katesur/Projects/LieSPHGP/venv/bin/python3 -u simulate_stageD.py + +Options: + --Kp Proportional gain [N·m·s/rad] (default 0.1) + --omega_star Target spin rate [rad/s] (default 2π) + --n_trials Number of trials (default 20) + --n_steps Steps per trial (dt=0.05s) (default 200 = 10s) + --axis_start Initial spin axis: 1=e₂ unstable, 0=e₁, 2=e₃ (default 1) + --conv_thr Convergence threshold ‖ω-ω*‖ (default 0.5 rad/s) + --seed RNG seed (default 0) + --Kp_scan Comma-separated Kp values to scan (overrides --Kp) +""" +import argparse +import os +import sys +import numpy as np + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'envs')) +sys.path.insert(0, THIS_FILE_DIR) + +from envs.tennis_racket_3d import tennis_racket_3d as Env +from controller_stageD import PDBodyFrameController + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── + +def get_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('--Kp', type=float, default=0.10) + p.add_argument('--omega_star', type=float, default=2 * np.pi, + help='Target spin speed about e₃ [rad/s]') + p.add_argument('--n_trials', type=int, default=20) + p.add_argument('--n_steps', type=int, default=200) + p.add_argument('--axis_start', type=int, default=1, + help='Initial axis: 0=e₁ stable, 1=e₂ unstable, 2=e₃ target') + p.add_argument('--conv_thr', type=float, default=0.50, + help='Convergence threshold on 
‖ω-ω*‖ [rad/s]') + p.add_argument('--seed', type=int, default=0) + p.add_argument('--Kp_scan', type=str, default=None, + help='Comma-separated Kp values to scan, e.g. "0.01,0.05,0.10,0.20"') + return p.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Single Kp run +# ───────────────────────────────────────────────────────────────────────────── + +def run_trials( + Kp: float, + omega_star_vec: np.ndarray, + n_trials: int, + n_steps: int, + axis_start: int, + conv_thr: float, + seed: int, + env: Env, + verbose: bool = True, +) -> dict: + """Run n_trials closed-loop episodes and return convergence statistics.""" + ctrl = PDBodyFrameController( + omega_star=omega_star_vec, + Kp=Kp, + use_model_g=False, + ) + + dt = env.dt + t_grid = np.arange(n_steps + 1) * dt # (n_steps+1,) includes t=0 + + # Storage: omega_error[trial, step] + err_traces = np.zeros((n_trials, n_steps + 1), dtype=np.float64) + conv_steps = [] + + for trial in range(n_trials): + obs, _ = env.reset(seed=seed + trial, + options={"axis": axis_start}) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + + err_traces[trial, 0] = ctrl.omega_error(omega) + + for step in range(1, n_steps + 1): + u = ctrl(R_flat, omega) + obs, _, _, _, _ = env.step(u) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + err_traces[trial, step] = ctrl.omega_error(omega) + + # First step where error drops below threshold + below = np.where(err_traces[trial] < conv_thr)[0] + conv_steps.append(int(below[0]) if len(below) else n_steps) + + conv_steps = np.array(conv_steps, dtype=float) + final_err = err_traces[:, -1] + + result = { + 'Kp': Kp, + 'err_traces': err_traces, # (n_trials, n_steps+1) + 't_grid': t_grid, + 'conv_steps': conv_steps, # (n_trials,) steps to conv + 'conv_times': conv_steps * dt, # (n_trials,) seconds to conv + 'final_err': final_err, # (n_trials,) final omega error + 'pct_converged': 100 * np.mean(conv_steps 
< n_steps), + } + + if verbose: + _print_result(result, n_steps, dt, conv_thr) + + return result + + +def _print_result(result: dict, n_steps: int, dt: float, conv_thr: float): + Kp = result['Kp'] + ct = result['conv_times'] + fe = result['final_err'] + err0 = result['err_traces'][:, 0] + pct = result['pct_converged'] + + print(f"\n Kp = {Kp:.4f}") + print(f" Initial ‖ω−ω*‖: {err0.mean():.3f} ± {err0.std():.3f} rad/s") + print(f" Final ‖ω−ω*‖: {fe.mean():.4f} ± {fe.std():.4f} rad/s") + print(f" Time to ‖ω−ω*‖ < {conv_thr:.2f}: " + f"{ct.mean():.3f} ± {ct.std():.3f} s " + f"(min={ct.min():.3f} max={ct.max():.3f})") + print(f" Converged: {pct:.0f}% of trials " + f"within {n_steps*dt:.1f}s") + + +# ───────────────────────────────────────────────────────────────────────────── +# Linearisation sanity check (analytical, no simulation needed) +# ───────────────────────────────────────────────────────────────────────────── + +def print_linearisation(I1, I2, I3, Kp, omega_star): + """Print eigenvalues of the linearised closed-loop at ω*.""" + # δω₃ mode + lam3 = -Kp / I3 + tau3 = -1.0 / lam3 + + # δω₁,δω₂ modes + # A = [-Kp/I1 w*(I2-I3)/I1] + # [w*(I3-I1)/I2 -Kp/I2 ] + a11 = -Kp / I1 + a12 = omega_star * (I2 - I3) / I1 + a21 = omega_star * (I3 - I1) / I2 + a22 = -Kp / I2 + + tr_A = a11 + a22 + det_A = a11 * a22 - a12 * a21 + disc = tr_A**2 - 4 * det_A + + if disc >= 0: + lam_real = (tr_A + np.sqrt(disc)) / 2, (tr_A - np.sqrt(disc)) / 2 + print(f" δω₁,₂ modes: overdamped λ = {lam_real[0]:.3f}, {lam_real[1]:.3f} s⁻¹ " + f"(τ_slow = {-1/lam_real[0]:.4f} s, τ_fast = {-1/lam_real[1]:.4f} s)") + else: + Re_lam = tr_A / 2 + Im_lam = np.sqrt(-disc) / 2 + freq_hz = Im_lam / (2 * np.pi) + print(f" δω₁,₂ modes: oscillatory λ = {Re_lam:.3f} ± {Im_lam:.3f}i s⁻¹ " + f"(τ = {-1/Re_lam:.4f} s, f = {freq_hz:.3f} Hz)") + print(f" δω₃ mode: λ = {lam3:.3f} s⁻¹ (τ = {tau3:.4f} s)") + # Overdamped when disc = Kp²(1/I1-1/I2)² - 4|a12·a21| ≥ 0 + # Threshold: Kp_od = 2√|a12·a21| / |1/I1-1/I2| 
(a12,a21 already carry ω*) + overdamped_kp = 2 * np.sqrt(abs(a12 * a21)) / abs(1/I1 - 1/I2) + print(f" Overdamped threshold Kp ≥ {overdamped_kp:.4f}") + + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── + +def main(): + args = get_args() + + omega_star_vec = np.array([0.0, 0.0, args.omega_star], dtype=np.float64) + + print("=" * 62) + print("Stage D — IDA-PBC Tennis Racket Stabilisation") + print("=" * 62) + print(f"\nTarget: ω* = (0, 0, {args.omega_star:.4f}) rad/s [{args.omega_star/(2*np.pi):.3f} rev/s]") + print(f"Initial: axis {args.axis_start} " + + {0: "(e₁ stable)", 1: "(e₂ UNSTABLE — Dzhanibekov)", 2: "(e₃ target)"}[args.axis_start]) + print(f"Trials: {args.n_trials} × {args.n_steps} steps " + f"(dt=0.05s → {args.n_steps*0.05:.1f}s per trial)") + print(f"Convergence threshold: ‖ω-ω*‖ < {args.conv_thr:.2f} rad/s") + + # Build env (default cfg0 geometry — head_a=0.195, head_b=0.135, etc.) 
+ env = Env(disturbance_torque_std=0.0, seed=args.seed) + iinfo = env.get_inertia_info() + I1, I2, I3 = iinfo['I1'], iinfo['I2'], iinfo['I3'] + + print(f"\nRacket (cfg0):") + print(f" I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²") + print(f" I2/I1={I2/I1:.2f} asym=(I3-I1)/I2={(I3-I1)/I2:.2f}") + print(f" head_a={iinfo['head_a']:.3f} head_b={iinfo['head_b']:.3f} " + f"handle={iinfo['handle_length']:.3f} mass={iinfo['total_mass']:.3f} kg") + + # ── Linearisation analysis ───────────────────────────────────────────── + print(f"\n── Linearised stability (Kp={args.Kp:.4f}) ──") + print_linearisation(I1, I2, I3, args.Kp, args.omega_star) + + # ── Simulation ──────────────────────────────────────────────────────── + Kp_list = ([float(k) for k in args.Kp_scan.split(',')] + if args.Kp_scan else [args.Kp]) + + all_results = {} + for Kp in Kp_list: + if args.Kp_scan: + print(f"\n── Scan Kp = {Kp:.4f} ──") + print_linearisation(I1, I2, I3, Kp, args.omega_star) + result = run_trials( + Kp=Kp, + omega_star_vec=omega_star_vec, + n_trials=args.n_trials, + n_steps=args.n_steps, + axis_start=args.axis_start, + conv_thr=args.conv_thr, + seed=args.seed, + env=env, + verbose=True, + ) + all_results[Kp] = result + + # ── Convergence curve summary (profile at select timesteps) ─────────── + # Among fully-converged Kps, prefer fastest convergence time; else fallback to last + fully_conv = [k for k, r in all_results.items() if r['pct_converged'] >= 95.0] + if fully_conv: + best_Kp = min(fully_conv, key=lambda k: all_results[k]['conv_times'].mean()) + else: + best_Kp = Kp_list[-1] + best = all_results[best_Kp] + + print(f"\n── Error profile (best Kp={best_Kp:.4f}, avg over {args.n_trials} trials) ──") + print(f"{'t [s]':>8} {'mean ‖ω-ω*‖':>14} {'max ‖ω-ω*‖':>14}") + profile_steps = list(range(0, min(21, args.n_steps + 1), 2)) + if args.n_steps not in profile_steps: + profile_steps.append(args.n_steps) + for step in profile_steps: + t_s = step * env.dt + mean_ = best['err_traces'][:, 
step].mean() + max_ = best['err_traces'][:, step].max() + print(f" {t_s:6.2f} {mean_:14.6f} {max_:14.6f}") + + # ── Verdict ──────────────────────────────────────────────────────────── + final_mean = best['final_err'].mean() + pct_conv = best['pct_converged'] + passed = pct_conv >= 95.0 and final_mean < 0.10 + + print(f"\n{'='*62}") + print(f"{'PASS ✓' if passed else 'FAIL ✗'} " + f"Kp={best_Kp:.4f} | " + f"{pct_conv:.0f}% converged | " + f"final ‖ω-ω*‖ = {final_mean:.4f} rad/s") + print(f"{'='*62}") + + env.close() + return all_results + + +if __name__ == '__main__': + main() diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE.py new file mode 100644 index 0000000..507e560 --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE.py @@ -0,0 +1,457 @@ +"""Stage E simulation: e₁ stabilisation, e₂ hold, wind robustness. + +Three scenarios: + A. Redirect + stabilise e₁ (ω* = (2π,0,0), start from e₂ tumbling, no wind) + B. Hold e₂ (ω* = (0,2π,0), start near e₂ with small perturbation, no wind) + — Kp scan to find empirical minimum Kp + C. Hold e₂ with stochastic wind (disturbance_torque_std=0.05 N·m) + +Linearised stability (free body, Euler's equations): + + For ω* = ω* eₙ, perturbations in the orthogonal plane form a 2D system. + Off-diagonal Euler coupling → eigenvalues: + + λ² = ω*² (I_L - I_T)(I_S - I_T) / (I_L · I_S) + + where I_T = inertia about target axis, I_S, I_L = smaller/larger inertia. 
+ + e₁ (smallest I): I_T < I_S < I_L → (I_L-I_T)>0, (I_S-I_T)>0, product>0 but wait + Actual formula for ω* = (ω,0,0): + δω̇₂ = (I₃-I₁)ω/I₂ · δω₃ + δω̇₃ = (I₁-I₂)ω/I₃ · δω₂ + λ² = (I₃-I₁)(I₁-I₂)ω²/(I₂I₃) → λ² < 0 (stable oscillations) + + e₂ (middle I): λ² > 0 (saddle, unstable) — Dzhanibekov + e₃ (largest I): λ² < 0 (stable oscillations) + + Kp_min for e₂ stabilisation: + det(A_cl) > 0 requires: Kp² > |(I₂-I₃)(I₁-I₂)| · ω*₂² + Kp_min = ω*₂ · √(|I₂-I₃| · |I₁-I₂|) + +Usage: + /Users/katesur/Projects/LieSPHGP/venv/bin/python3 -u simulate_stageE.py + +Options: + --omega_star Spin rate [rad/s] (default 2π) + --n_trials Trials per Kp value (default 20) + --n_steps Steps per trial (dt=0.05s) (default 200 = 10s) + --hold_perturb ‖δω‖ for hold IC [rad/s] (default 0.3) + --wind_std Disturbance std [N·m] (default 0.05) + --conv_thr Convergence threshold (default 0.5 rad/s) + --seed RNG seed (default 0) +""" +import argparse +import os +import sys +import numpy as np + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'envs')) +sys.path.insert(0, THIS_FILE_DIR) + +from envs.tennis_racket_3d import tennis_racket_3d as Env +from controller_stageD import PDBodyFrameController + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── + +def get_args(): + p = argparse.ArgumentParser() + p.add_argument('--omega_star', type=float, default=2 * np.pi) + p.add_argument('--n_trials', type=int, default=20) + p.add_argument('--n_steps', type=int, default=200) + p.add_argument('--hold_perturb', type=float, default=0.30, + help='std of ω perturbation for hold IC [rad/s]') + p.add_argument('--wind_std', type=float, default=0.05, + help='disturbance_torque_std for wind scenario') + 
p.add_argument('--conv_thr', type=float, default=0.50) + p.add_argument('--seed', type=int, default=0) + return p.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Stability helpers +# ───────────────────────────────────────────────────────────────────────────── + +def free_body_eigenvalue(I1, I2, I3, omega_star, target_axis): + """Eigenvalue of the free-body linearisation about ω* = ω_star · e_axis. + + Returns the real part λ_real and imaginary part λ_imag of one eigenvalue. + If λ_real > 0: unstable. If λ_real = 0: neutral (oscillatory). + """ + # δω̇_a = (I_c - I_t) ω* / I_a · δω_b + # δω̇_b = (I_t - I_b) ω* / I_b ... use component form + I = [I1, I2, I3] + # Indices of the two orthogonal axes + a, b = [(target_axis + 1) % 3, (target_axis + 2) % 3] + It = I[target_axis] + + # From Euler equations, linearised: + # Ia * δω̇_a = (Ib - It) * ω* * δω_b [from Ia ω̇_a = (Ij - Ik) ωj ωk] + # Ib * δω̇_b = (It - Ia) * ω* * δω_a + Ia, Ib = I[a], I[b] + m12 = (Ib - It) * omega_star / Ia + m21 = (It - Ia) * omega_star / Ib + + lam_sq = m12 * m21 # = (Ib-It)(It-Ia)ω*² / (IaIb) + + if lam_sq > 1e-15: + lam = np.sqrt(lam_sq) + return lam, 0.0 # real ±λ (unstable saddle) + elif lam_sq < -1e-15: + return 0.0, np.sqrt(-lam_sq) # ±i|λ| (stable oscillation) + else: + return 0.0, 0.0 # borderline (marginal) + + +def kp_min_e2(I1, I2, I3, omega_star): + """Minimum Kp to stabilise spinning about the intermediate axis. + + From det(A_cl) > 0: + Kp² > |(I₂-I₃)(I₁-I₂)| · ω*² + Kp_min = ω* · √(|I₂-I₃| · |I₁-I₂|) + """ + return omega_star * np.sqrt(abs(I2 - I3) * abs(I1 - I2)) + + +def kp_max_zoh(I1, dt): + """Maximum Kp before ZOH discrete-time instability on the I₁ (smallest) axis. + + With ZOH period T = dt (env.dt), the discrete-time eigenvalue for the I₁ axis is: + z = 1 - Kp · T / I₁ + Stable iff |z| < 1, i.e. Kp < 2 · I₁ / T. 
+ """ + return 2.0 * I1 / dt + + +def print_stability_table(I1, I2, I3, omega_star, dt): + """Print free-body eigenvalues and Kp bounds for all three axes.""" + names = ["e₁ (smallest I, long head)", "e₂ (middle I, short head — UNSTABLE)", + "e₃ (largest I, handle)"] + kp_max = kp_max_zoh(I1, dt) + print(f" {'Axis':<36} {'λ_free':>12} {'K_p_min':>10} {'Status'}") + print(f" {'-'*36} {'-'*12} {'-'*10} {'-'*15}") + for axis in range(3): + lr, li = free_body_eigenvalue(I1, I2, I3, omega_star, axis) + if lr > 1e-10: + lam_str = f"+{lr:.3f} (real)" + status = "UNSTABLE" + elif li > 1e-10: + lam_str = f"±{li:.3f}i" + status = "stable" + else: + lam_str = "0" + status = "marginal" + + kp_m = (kp_min_e2(I1, I2, I3, omega_star) if axis == 1 else 0.0) + print(f" {names[axis]:<36} {lam_str:>12} {kp_m:>10.4f} {status}") + print(f"\n ZOH stability bound (K_p_max = 2·I₁/T): {kp_max:.4f} N·m·s/rad " + f"[T = dt = {dt:.3f} s]") + print(f" Safe operating range for e₂ hold: " + f"({kp_min_e2(I1, I2, I3, omega_star):.4f}, {kp_max:.4f})") + + +# ───────────────────────────────────────────────────────────────────────────── +# Core simulation +# ───────────────────────────────────────────────────────────────────────────── + +def run_trials( + env: Env, + omega_star_vec: np.ndarray, + Kp: float, + n_trials: int, + n_steps: int, + conv_thr: float, + seed: int, + ic_mode: str = "axis", # "axis" or "hold" + ic_axis: int = 1, # for ic_mode="axis" + hold_perturb: float = 0.3, # for ic_mode="hold" + rng_seed_offset: int = 0, +) -> dict: + """ + ic_mode = "axis": env.reset(options={"axis": ic_axis}) + ic_mode = "hold": omega_init = omega_star_vec + N(0, hold_perturb²) + with a random R from env.reset() + """ + ctrl = PDBodyFrameController(omega_star=omega_star_vec, Kp=Kp) + rng = np.random.default_rng(seed + rng_seed_offset) + + dt = env.dt + err_traces = np.zeros((n_trials, n_steps + 1), dtype=np.float64) + disturbance_norms = np.zeros(n_trials, dtype=np.float64) + conv_steps = [] + + for 
trial in range(n_trials): + if ic_mode == "axis": + obs, _ = env.reset(seed=seed + trial, + options={"axis": ic_axis}) + else: # hold + noise = rng.normal(0.0, hold_perturb, 3) + omega_0 = omega_star_vec + noise + # Reset gets a random R but we override omega + obs, _ = env.reset(seed=seed + trial, + options={"omega_init": omega_0}) + + disturbance_norms[trial] = float( + np.linalg.norm(env._disturbance_torque)) + + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + err_traces[trial, 0] = ctrl.omega_error(omega) + + for step in range(1, n_steps + 1): + u = ctrl(R_flat, omega) + obs, _, _, _, _ = env.step(u) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + err_traces[trial, step] = ctrl.omega_error(omega) + + below = np.where(err_traces[trial] < conv_thr)[0] + conv_steps.append(int(below[0]) if len(below) else n_steps) + + conv_steps = np.array(conv_steps, dtype=float) + final_err = err_traces[:, -1] + late_max = err_traces[:, n_steps // 2:].max(axis=1) # max in 2nd half + + return { + 'Kp': Kp, + 'err_traces': err_traces, + 't_grid': np.arange(n_steps + 1) * dt, + 'conv_steps': conv_steps, + 'conv_times': conv_steps * dt, + 'final_err': final_err, + 'late_max': late_max, # stability indicator + 'disturbance_norms': disturbance_norms, + 'pct_converged': 100 * np.mean(conv_steps < n_steps), + } + + +def print_trial_result(r: dict, n_steps: int, dt: float, conv_thr: float, + show_dist: bool = False): + Kp = r['Kp'] + ct = r['conv_times'] + fe = r['final_err'] + lm = r['late_max'] + pct = r['pct_converged'] + err0 = r['err_traces'][:, 0] + + label = ("UNSTABLE" if (pct < 50 or lm.mean() > 1.0) + else "CONVERGED" if pct >= 95 else "PARTIAL") + + print(f"\n Kp = {Kp:.4f} [{label}]") + print(f" Initial ‖ω−ω*‖: {err0.mean():.3f} ± {err0.std():.3f} rad/s") + print(f" Final ‖ω−ω*‖: {fe.mean():.4f} ± {fe.std():.4f} rad/s") + print(f" Late max ‖ω−ω*‖: {lm.mean():.4f} ± {lm.std():.4f} rad/s " + f"(2nd half of traj)") + if 
pct > 0: + print(f" Conv time (<{conv_thr:.1f}): " + f"{ct.mean():.3f} ± {ct.std():.3f} s " + f"(min={ct.min():.2f} max={ct.max():.2f})") + print(f" Converged: {pct:.0f}% of {len(fe)} trials " + f"within {n_steps*dt:.1f}s") + if show_dist: + dn = r['disturbance_norms'] + print(f" Disturbance |d|: {dn.mean():.4f} ± {dn.std():.4f} N·m " + f"(expected SS offset ≈ |d|/Kp = {dn.mean()/Kp:.3f} rad/s)") + + +# ───────────────────────────────────────────────────────────────────────────── +# Scenario runners +# ───────────────────────────────────────────────────────────────────────────── + +def scenario_A(args, I1, I2, I3, env_clean): + """Redirect from e₂ tumbling → e₁. No wind.""" + print(f"\n{'─'*60}") + print(f"Scenario A: Stabilise e₁ (from e₂ tumbling, no wind)") + print(f" Target ω* = ({args.omega_star:.4f}, 0, 0) rad/s") + print(f" IC: axis=1 (Dzhanibekov tumbling) Kp = 0.10") + print(f"{'─'*60}") + + omega_star = np.array([args.omega_star, 0.0, 0.0]) + Kp = 0.10 + + r = run_trials( + env=env_clean, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="axis", ic_axis=1, + ) + print_trial_result(r, args.n_steps, env_clean.dt, args.conv_thr) + + # Error profile + print(f"\n Error profile (avg over {args.n_trials} trials):") + print(f" {'t [s]':>8} {'mean ‖ω-ω*‖':>14} {'max ‖ω-ω*‖':>14}") + steps = list(range(0, min(21, args.n_steps + 1), 2)) + if args.n_steps not in steps: + steps.append(args.n_steps) + for step in steps: + t_s = step * env_clean.dt + mean_ = r['err_traces'][:, step].mean() + max_ = r['err_traces'][:, step].max() + print(f" {t_s:8.2f} {mean_:14.6f} {max_:14.6f}") + + return r + + +def scenario_B(args, I1, I2, I3, env_clean): + """Hold e₂ from near-target IC. No wind. 
K_p scan.""" + print(f"\n{'─'*60}") + print(f"Scenario B: Hold e₂ (start near ω*, no wind)") + print(f" Target ω* = (0, {args.omega_star:.4f}, 0) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²) [perturbation about e₂]") + Kp_th = kp_min_e2(I1, I2, I3, args.omega_star) + print(f" Theoretical K_p_min = {Kp_th:.4f} N·m·s/rad") + print(f"{'─'*60}") + + omega_star = np.array([0.0, args.omega_star, 0.0]) + + # Scan: below threshold, near threshold, well above, and beyond ZOH bound + Kp_list = [0.010, 0.015, 0.020, 0.025, 0.030, 0.050, 0.100, 0.500] + results = {} + for Kp in Kp_list: + r = run_trials( + env=env_clean, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_clean.dt, args.conv_thr) + results[Kp] = r + + # Find empirical threshold (smallest Kp with ≥95% convergence) + kp_emp = None + for Kp in sorted(results): + r = results[Kp] + if r['pct_converged'] >= 95 and r['late_max'].mean() < 0.5: + kp_emp = Kp + break + + kp_max = kp_max_zoh(I1, env_clean.dt) + print(f"\n Summary:") + print(f" Theoretical K_p_min (saddle) = {Kp_th:.4f}") + print(f" Theoretical K_p_max (ZOH) = {kp_max:.4f} " + f"[2·I₁/T = 2×{I1:.4f}/{env_clean.dt:.3f}]") + if kp_emp is not None: + print(f" Empirical K_p_min ≤ {kp_emp:.4f} " + f"(first Kp with ≥95% convergence + stable hold)") + else: + print(f" Empirical K_p_min > {max(Kp_list):.4f} " + "(no Kp in scan converged reliably)") + # Check ZOH instability empirically + if 0.500 in results: + r_zoh = results[0.500] + zoh_unstable = r_zoh['late_max'].mean() > 1.0 and r_zoh['final_err'].mean() > 1.0 + print(f" ZOH instability at K_p=0.50: " + f"{'CONFIRMED ✓' if zoh_unstable else 'not observed'}" + f" (late_max={r_zoh['late_max'].mean():.2f} rad/s)") + + return results + + +def scenario_C(args, I1, I2, I3, wind_std): + """Hold e₂ with stochastic wind. 
K_p scan.""" + print(f"\n{'─'*60}") + print(f"Scenario C: Hold e₂ with wind") + print(f" Target ω* = (0, {args.omega_star:.4f}, 0) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²)") + print(f" Wind: disturbance_torque_std = {wind_std:.3f} N·m (constant per episode)") + Kp_th = kp_min_e2(I1, I2, I3, args.omega_star) + print(f" K_p_min (no-wind) = {Kp_th:.4f} N·m·s/rad") + print(f"{'─'*60}") + + omega_star = np.array([0.0, args.omega_star, 0.0]) + Kp_list = [0.05, 0.10, 0.20, 0.50] + + results = {} + for Kp in Kp_list: + # Build a fresh env with wind enabled + env_wind = Env(disturbance_torque_std=wind_std, seed=args.seed) + r = run_trials( + env=env_wind, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_wind.dt, args.conv_thr, + show_dist=True) + env_wind.close() + results[Kp] = r + + # Compare actual final error vs predicted steady-state offset + print(f"\n Final-error vs predicted SS offset (|d|/Kp):") + print(f" {'Kp':>8} {'final err (mean)':>18} {'predicted SS':>14} {'ratio':>8}") + for Kp, r in results.items(): + dn = r['disturbance_norms'].mean() + pred = dn / Kp + fe = r['final_err'].mean() + ratio = fe / pred if pred > 0 else float('nan') + print(f" {Kp:8.4f} {fe:18.4f} {pred:14.4f} {ratio:8.3f}") + + return results + + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── + +def main(): + args = get_args() + + print("=" * 62) + print("Stage E — e₁ stabilisation, e₂ hold, wind robustness") + print("=" * 62) + + # Default env (no wind, cfg0) + env_clean = Env(disturbance_torque_std=0.0, seed=args.seed) + iinfo = env_clean.get_inertia_info() + I1, I2, I3 = iinfo['I1'], iinfo['I2'], iinfo['I3'] + + print(f"\nRacket (cfg0):") + print(f" I1={I1:.6f} I2={I2:.6f} 
I3={I3:.6f} kg·m²") + print(f" I2/I1={I2/I1:.2f} asym=(I3-I1)/I2={(I3-I1)/I2:.2f}") + print(f"\nTrials: {args.n_trials} × {args.n_steps} steps " + f"(dt={env_clean.dt:.2f}s → {args.n_steps*env_clean.dt:.1f}s per trial)") + print(f"Convergence threshold: ‖ω-ω*‖ < {args.conv_thr:.2f} rad/s") + + # Stability table + print(f"\n── Free-body stability (ω* = {args.omega_star:.4f} rad/s, cfg0) ──") + print_stability_table(I1, I2, I3, args.omega_star, env_clean.dt) + + # Run all three scenarios + resA = scenario_A(args, I1, I2, I3, env_clean) + resB = scenario_B(args, I1, I2, I3, env_clean) + resC = scenario_C(args, I1, I2, I3, args.wind_std) + + # ── Final verdict ────────────────────────────────────────────────────── + A_pass = resA['pct_converged'] >= 95 + # Scenario B: K_p=0.10 is always in the scan + B_r010 = resB.get(0.100, resB[max(resB)]) + B_pass = B_r010['pct_converged'] >= 95 and B_r010['late_max'].mean() < 0.5 + # Find best Kp in Scenario C (highest convergence + stable hold) + C_best_kp = max( + (k for k in resC if resC[k]['late_max'].mean() < 0.5), + key=lambda k: resC[k]['pct_converged'], + default=None, + ) + C_r_best = resC[C_best_kp] if C_best_kp is not None else resC[min(resC)] + C_pass = C_r_best['pct_converged'] >= 95 + + print(f"\n{'='*62}") + print(f"Stage E summary (Kp=0.10 unless noted)") + print(f" A — e₁ stabilisation: {'PASS ✓' if A_pass else 'FAIL ✗'}") + print(f" B — e₂ hold (no wind): {'PASS ✓' if B_pass else 'FAIL ✗'}") + kp_max = kp_max_zoh(I1, env_clean.dt) + print(f" C — e₂ hold (wind {args.wind_std:.2f} N·m): " + f"{'PASS ✓' if C_pass else 'FAIL ✗'}" + + (f" [best Kp={C_best_kp:.3f}]" if C_best_kp else "")) + print(f"{'='*62}") + + env_clean.close() + + +if __name__ == '__main__': + main() diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE_v2.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE_v2.py new file mode 100644 index 0000000..afe5333 --- /dev/null +++ 
b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageE_v2.py @@ -0,0 +1,575 @@ +"""Stage E simulation: e₁ stabilisation, e₂ hold, wind robustness. + +Five scenarios: + A. Redirect + stabilise e₁ (ω* = (2π,0,0), start from e₂ tumbling, no wind) + B. Hold e₂ (ω* = (0,2π,0), start near e₂ with small perturbation, no wind) + — Kp scan to find empirical minimum Kp + C. Hold e₂ with stochastic wind (disturbance_torque_std=0.05 N·m) + D. Hold e₁ with stochastic wind (disturbance_torque_std=0.05 N·m) + — e₁ is stable; K_p_min=0 (no saddle), only ZOH bound K_p_max=2·I₁/T + E. Hold e₃ with stochastic wind (disturbance_torque_std=0.05 N·m) + — e₃ is stable; same ZOH bound applies + +Linearised stability (free body, Euler's equations): + + For ω* = ω* eₙ, perturbations in the orthogonal plane form a 2D system. + Off-diagonal Euler coupling → eigenvalues: + + λ² = ω*² (I_L - I_T)(I_S - I_T) / (I_L · I_S) + + where I_T = inertia about target axis, I_S, I_L = smaller/larger inertia. + + e₁ (smallest I): I_T < I_S < I_L → (I_L-I_T)>0, (I_S-I_T)>0, product>0 but wait + Actual formula for ω* = (ω,0,0): + δω̇₂ = (I₃-I₁)ω/I₂ · δω₃ + δω̇₃ = (I₁-I₂)ω/I₃ · δω₂ + λ² = (I₃-I₁)(I₁-I₂)ω²/(I₂I₃) → λ² < 0 (stable oscillations) + + e₂ (middle I): λ² > 0 (saddle, unstable) — Dzhanibekov + e₃ (largest I): λ² < 0 (stable oscillations) + + Kp_min for e₂ stabilisation: + det(A_cl) > 0 requires: Kp² > |(I₂-I₃)(I₁-I₂)| · ω*₂² + Kp_min = ω*₂ · √(|I₂-I₃| · |I₁-I₂|) + +Usage: + /Users/katesur/Projects/LieSPHGP/venv/bin/python3 -u simulate_stageE.py + +Options: + --omega_star Spin rate [rad/s] (default 2π) + --n_trials Trials per Kp value (default 20) + --n_steps Steps per trial (dt=0.05s) (default 200 = 10s) + --hold_perturb ‖δω‖ for hold IC [rad/s] (default 0.3) + --wind_std Disturbance std [N·m] (default 0.05) + --conv_thr Convergence threshold (default 0.5 rad/s) + --seed RNG seed (default 0) +""" +import argparse +import os +import sys +import numpy as np + +THIS_FILE_DIR = 
os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'envs')) +sys.path.insert(0, THIS_FILE_DIR) + +from envs.tennis_racket_3d import tennis_racket_3d as Env +from controller_stageD import PDBodyFrameController + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── + +def get_args(): + p = argparse.ArgumentParser() + p.add_argument('--omega_star', type=float, default=2 * np.pi) + p.add_argument('--n_trials', type=int, default=20) + p.add_argument('--n_steps', type=int, default=200) + p.add_argument('--hold_perturb', type=float, default=0.30, + help='std of ω perturbation for hold IC [rad/s]') + p.add_argument('--wind_std', type=float, default=0.05, + help='disturbance_torque_std for wind scenario') + p.add_argument('--conv_thr', type=float, default=0.50) + p.add_argument('--seed', type=int, default=0) + return p.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Stability helpers +# ───────────────────────────────────────────────────────────────────────────── + +def free_body_eigenvalue(I1, I2, I3, omega_star, target_axis): + """Eigenvalue of the free-body linearisation about ω* = ω_star · e_axis. + + Returns the real part λ_real and imaginary part λ_imag of one eigenvalue. + If λ_real > 0: unstable. If λ_real = 0: neutral (oscillatory). + """ + # δω̇_a = (I_c - I_t) ω* / I_a · δω_b + # δω̇_b = (I_t - I_b) ω* / I_b ... 
use component form + I = [I1, I2, I3] + # Indices of the two orthogonal axes + a, b = [(target_axis + 1) % 3, (target_axis + 2) % 3] + It = I[target_axis] + + # From Euler equations, linearised: + # Ia * δω̇_a = (Ib - It) * ω* * δω_b [from Ia ω̇_a = (Ij - Ik) ωj ωk] + # Ib * δω̇_b = (It - Ia) * ω* * δω_a + Ia, Ib = I[a], I[b] + m12 = (Ib - It) * omega_star / Ia + m21 = (It - Ia) * omega_star / Ib + + lam_sq = m12 * m21 # = (Ib-It)(It-Ia)ω*² / (IaIb) + + if lam_sq > 1e-15: + lam = np.sqrt(lam_sq) + return lam, 0.0 # real ±λ (unstable saddle) + elif lam_sq < -1e-15: + return 0.0, np.sqrt(-lam_sq) # ±i|λ| (stable oscillation) + else: + return 0.0, 0.0 # borderline (marginal) + + +def kp_min_e2(I1, I2, I3, omega_star): + """Minimum Kp to stabilise spinning about the intermediate axis. + + From det(A_cl) > 0: + Kp² > |(I₂-I₃)(I₁-I₂)| · ω*² + Kp_min = ω* · √(|I₂-I₃| · |I₁-I₂|) + """ + return omega_star * np.sqrt(abs(I2 - I3) * abs(I1 - I2)) + + +def kp_max_zoh(I1, dt): + """Maximum Kp before ZOH discrete-time instability on the I₁ (smallest) axis. + + With ZOH period T = dt (env.dt), the discrete-time eigenvalue for the I₁ axis is: + z = 1 - Kp · T / I₁ + Stable iff |z| < 1, i.e. Kp < 2 · I₁ / T. 
+ """ + return 2.0 * I1 / dt + + +def print_stability_table(I1, I2, I3, omega_star, dt): + """Print free-body eigenvalues and Kp bounds for all three axes.""" + names = ["e₁ (smallest I, long head)", "e₂ (middle I, short head — UNSTABLE)", + "e₃ (largest I, handle)"] + kp_max = kp_max_zoh(I1, dt) + print(f" {'Axis':<36} {'λ_free':>12} {'K_p_min':>10} {'Status'}") + print(f" {'-'*36} {'-'*12} {'-'*10} {'-'*15}") + for axis in range(3): + lr, li = free_body_eigenvalue(I1, I2, I3, omega_star, axis) + if lr > 1e-10: + lam_str = f"+{lr:.3f} (real)" + status = "UNSTABLE" + elif li > 1e-10: + lam_str = f"±{li:.3f}i" + status = "stable" + else: + lam_str = "0" + status = "marginal" + + kp_m = (kp_min_e2(I1, I2, I3, omega_star) if axis == 1 else 0.0) + print(f" {names[axis]:<36} {lam_str:>12} {kp_m:>10.4f} {status}") + print(f"\n ZOH stability bound (K_p_max = 2·I₁/T): {kp_max:.4f} N·m·s/rad " + f"[T = dt = {dt:.3f} s]") + print(f" Safe operating range for e₂ hold: " + f"({kp_min_e2(I1, I2, I3, omega_star):.4f}, {kp_max:.4f})") + + +# ───────────────────────────────────────────────────────────────────────────── +# Core simulation +# ───────────────────────────────────────────────────────────────────────────── + +def run_trials( + env: Env, + omega_star_vec: np.ndarray, + Kp: float, + n_trials: int, + n_steps: int, + conv_thr: float, + seed: int, + ic_mode: str = "axis", # "axis" or "hold" + ic_axis: int = 1, # for ic_mode="axis" + hold_perturb: float = 0.3, # for ic_mode="hold" + rng_seed_offset: int = 0, +) -> dict: + """ + ic_mode = "axis": env.reset(options={"axis": ic_axis}) + ic_mode = "hold": omega_init = omega_star_vec + N(0, hold_perturb²) + with a random R from env.reset() + """ + ctrl = PDBodyFrameController(omega_star=omega_star_vec, Kp=Kp) + rng = np.random.default_rng(seed + rng_seed_offset) + + dt = env.dt + err_traces = np.zeros((n_trials, n_steps + 1), dtype=np.float64) + disturbance_norms = np.zeros(n_trials, dtype=np.float64) + conv_steps = [] + + for 
trial in range(n_trials): + if ic_mode == "axis": + obs, _ = env.reset(seed=seed + trial, + options={"axis": ic_axis}) + else: # hold + noise = rng.normal(0.0, hold_perturb, 3) + omega_0 = omega_star_vec + noise + # Reset gets a random R but we override omega + obs, _ = env.reset(seed=seed + trial, + options={"omega_init": omega_0}) + + disturbance_norms[trial] = float( + np.linalg.norm(env._disturbance_torque)) + + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + err_traces[trial, 0] = ctrl.omega_error(omega) + + for step in range(1, n_steps + 1): + u = ctrl(R_flat, omega) + obs, _, _, _, _ = env.step(u) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + err_traces[trial, step] = ctrl.omega_error(omega) + + below = np.where(err_traces[trial] < conv_thr)[0] + conv_steps.append(int(below[0]) if len(below) else n_steps) + + conv_steps = np.array(conv_steps, dtype=float) + final_err = err_traces[:, -1] + late_max = err_traces[:, n_steps // 2:].max(axis=1) # max in 2nd half + + return { + 'Kp': Kp, + 'err_traces': err_traces, + 't_grid': np.arange(n_steps + 1) * dt, + 'conv_steps': conv_steps, + 'conv_times': conv_steps * dt, + 'final_err': final_err, + 'late_max': late_max, # stability indicator + 'disturbance_norms': disturbance_norms, + 'pct_converged': 100 * np.mean(conv_steps < n_steps), + } + + +def print_trial_result(r: dict, n_steps: int, dt: float, conv_thr: float, + show_dist: bool = False): + Kp = r['Kp'] + ct = r['conv_times'] + fe = r['final_err'] + lm = r['late_max'] + pct = r['pct_converged'] + err0 = r['err_traces'][:, 0] + + label = ("UNSTABLE" if (pct < 50 or lm.mean() > 1.0) + else "CONVERGED" if pct >= 95 else "PARTIAL") + + print(f"\n Kp = {Kp:.4f} [{label}]") + print(f" Initial ‖ω−ω*‖: {err0.mean():.3f} ± {err0.std():.3f} rad/s") + print(f" Final ‖ω−ω*‖: {fe.mean():.4f} ± {fe.std():.4f} rad/s") + print(f" Late max ‖ω−ω*‖: {lm.mean():.4f} ± {lm.std():.4f} rad/s " + f"(2nd half of traj)") + if 
pct > 0: + print(f" Conv time (<{conv_thr:.1f}): " + f"{ct.mean():.3f} ± {ct.std():.3f} s " + f"(min={ct.min():.2f} max={ct.max():.2f})") + print(f" Converged: {pct:.0f}% of {len(fe)} trials " + f"within {n_steps*dt:.1f}s") + if show_dist: + dn = r['disturbance_norms'] + print(f" Disturbance |d|: {dn.mean():.4f} ± {dn.std():.4f} N·m " + f"(expected SS offset ≈ |d|/Kp = {dn.mean()/Kp:.3f} rad/s)") + + +# ───────────────────────────────────────────────────────────────────────────── +# Scenario runners +# ───────────────────────────────────────────────────────────────────────────── + +def scenario_A(args, I1, I2, I3, env_clean): + """Redirect from e₂ tumbling → e₁. No wind.""" + print(f"\n{'─'*60}") + print(f"Scenario A: Stabilise e₁ (from e₂ tumbling, no wind)") + print(f" Target ω* = ({args.omega_star:.4f}, 0, 0) rad/s") + print(f" IC: axis=1 (Dzhanibekov tumbling) Kp = 0.10") + print(f"{'─'*60}") + + omega_star = np.array([args.omega_star, 0.0, 0.0]) + Kp = 0.10 + + r = run_trials( + env=env_clean, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="axis", ic_axis=1, + ) + print_trial_result(r, args.n_steps, env_clean.dt, args.conv_thr) + + # Error profile + print(f"\n Error profile (avg over {args.n_trials} trials):") + print(f" {'t [s]':>8} {'mean ‖ω-ω*‖':>14} {'max ‖ω-ω*‖':>14}") + steps = list(range(0, min(21, args.n_steps + 1), 2)) + if args.n_steps not in steps: + steps.append(args.n_steps) + for step in steps: + t_s = step * env_clean.dt + mean_ = r['err_traces'][:, step].mean() + max_ = r['err_traces'][:, step].max() + print(f" {t_s:8.2f} {mean_:14.6f} {max_:14.6f}") + + return r + + +def scenario_B(args, I1, I2, I3, env_clean): + """Hold e₂ from near-target IC. No wind. 
K_p scan.""" + print(f"\n{'─'*60}") + print(f"Scenario B: Hold e₂ (start near ω*, no wind)") + print(f" Target ω* = (0, {args.omega_star:.4f}, 0) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²) [perturbation about e₂]") + Kp_th = kp_min_e2(I1, I2, I3, args.omega_star) + print(f" Theoretical K_p_min = {Kp_th:.4f} N·m·s/rad") + print(f"{'─'*60}") + + omega_star = np.array([0.0, args.omega_star, 0.0]) + + # Scan: below threshold, near threshold, well above, and beyond ZOH bound + Kp_list = [0.010, 0.015, 0.020, 0.025, 0.030, 0.050, 0.100, 0.500] + results = {} + for Kp in Kp_list: + r = run_trials( + env=env_clean, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_clean.dt, args.conv_thr) + results[Kp] = r + + # Find empirical threshold (smallest Kp with ≥95% convergence) + kp_emp = None + for Kp in sorted(results): + r = results[Kp] + if r['pct_converged'] >= 95 and r['late_max'].mean() < 0.5: + kp_emp = Kp + break + + kp_max = kp_max_zoh(I1, env_clean.dt) + print(f"\n Summary:") + print(f" Theoretical K_p_min (saddle) = {Kp_th:.4f}") + print(f" Theoretical K_p_max (ZOH) = {kp_max:.4f} " + f"[2·I₁/T = 2×{I1:.4f}/{env_clean.dt:.3f}]") + if kp_emp is not None: + print(f" Empirical K_p_min ≤ {kp_emp:.4f} " + f"(first Kp with ≥95% convergence + stable hold)") + else: + print(f" Empirical K_p_min > {max(Kp_list):.4f} " + "(no Kp in scan converged reliably)") + # Check ZOH instability empirically + if 0.500 in results: + r_zoh = results[0.500] + zoh_unstable = r_zoh['late_max'].mean() > 1.0 and r_zoh['final_err'].mean() > 1.0 + print(f" ZOH instability at K_p=0.50: " + f"{'CONFIRMED ✓' if zoh_unstable else 'not observed'}" + f" (late_max={r_zoh['late_max'].mean():.2f} rad/s)") + + return results + + +def scenario_C(args, I1, I2, I3, wind_std): + """Hold e₂ with stochastic wind. 
K_p scan.""" + print(f"\n{'─'*60}") + print(f"Scenario C: Hold e₂ with wind") + print(f" Target ω* = (0, {args.omega_star:.4f}, 0) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²)") + print(f" Wind: disturbance_torque_std = {wind_std:.3f} N·m (constant per episode)") + Kp_th = kp_min_e2(I1, I2, I3, args.omega_star) + print(f" K_p_min (no-wind) = {Kp_th:.4f} N·m·s/rad") + print(f"{'─'*60}") + + omega_star = np.array([0.0, args.omega_star, 0.0]) + Kp_list = [0.05, 0.10, 0.20, 0.50] + + results = {} + for Kp in Kp_list: + # Build a fresh env with wind enabled + env_wind = Env(disturbance_torque_std=wind_std, seed=args.seed) + r = run_trials( + env=env_wind, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_wind.dt, args.conv_thr, + show_dist=True) + env_wind.close() + results[Kp] = r + + # Compare actual final error vs predicted steady-state offset + print(f"\n Final-error vs predicted SS offset (|d|/Kp):") + print(f" {'Kp':>8} {'final err (mean)':>18} {'predicted SS':>14} {'ratio':>8}") + for Kp, r in results.items(): + dn = r['disturbance_norms'].mean() + pred = dn / Kp + fe = r['final_err'].mean() + ratio = fe / pred if pred > 0 else float('nan') + print(f" {Kp:8.4f} {fe:18.4f} {pred:14.4f} {ratio:8.3f}") + + return results + + +def scenario_D_e1_wind(args, I1, I2, I3, wind_std): + """Hold e₁ with stochastic wind. K_p scan. + + e₁ is the stable (smallest-I) axis. The free-body eigenvalue is imaginary + (stable oscillation), so K_p_min = 0 in the absence of wind. With wind the + steady-state offset is |d|/K_p, so K_p must be large enough to keep that + offset below the convergence threshold — and small enough to stay below the + ZOH discrete-stability bound K_p_max = 2·I₁/T. 
+ """ + print(f"\n{'─'*60}") + print(f"Scenario D: Hold e₁ with wind") + print(f" Target ω* = ({args.omega_star:.4f}, 0, 0) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²)") + print(f" Wind: disturbance_torque_std = {wind_std:.3f} N·m (constant per episode)") + kp_max = kp_max_zoh(I1, 0.05) + lr, li = free_body_eigenvalue(I1, I2, I3, args.omega_star, target_axis=0) + lam_str = f"±{li:.3f}i (stable osc)" if li > 1e-10 else f"{lr:.3f} (unstable)" + print(f" Free-body λ = {lam_str}; K_p_min=0 K_p_max={kp_max:.4f}") + print(f"{'─'*60}") + + omega_star = np.array([args.omega_star, 0.0, 0.0]) + Kp_list = [0.05, 0.10, 0.20, 0.50] + + results = {} + for Kp in Kp_list: + env_wind = Env(disturbance_torque_std=wind_std, seed=args.seed) + r = run_trials( + env=env_wind, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_wind.dt, args.conv_thr, + show_dist=True) + env_wind.close() + results[Kp] = r + + print(f"\n Final-error vs predicted SS offset (|d|/Kp):") + print(f" {'Kp':>8} {'final err (mean)':>18} {'predicted SS':>14} {'ratio':>8}") + for Kp, r in results.items(): + dn = r['disturbance_norms'].mean() + pred = dn / Kp + fe = r['final_err'].mean() + ratio = fe / pred if pred > 0 else float('nan') + print(f" {Kp:8.4f} {fe:18.4f} {pred:14.4f} {ratio:8.3f}") + + return results + + +def scenario_E_e3_wind(args, I1, I2, I3, wind_std): + """Hold e₃ with stochastic wind. K_p scan. + + e₃ is the stable (largest-I) axis. Same analysis as e₁+wind: + K_p_min = 0 from stability, K_p_max = 2·I₁/T from ZOH, + wind adds steady-state offset |d|/K_p. 
+ """ + print(f"\n{'─'*60}") + print(f"Scenario E: Hold e₃ with wind") + print(f" Target ω* = (0, 0, {args.omega_star:.4f}) rad/s") + print(f" IC: ω* + N(0, {args.hold_perturb:.2f}²)") + print(f" Wind: disturbance_torque_std = {wind_std:.3f} N·m (constant per episode)") + kp_max = kp_max_zoh(I1, 0.05) + lr, li = free_body_eigenvalue(I1, I2, I3, args.omega_star, target_axis=2) + lam_str = f"±{li:.3f}i (stable osc)" if li > 1e-10 else f"{lr:.3f} (unstable)" + print(f" Free-body λ = {lam_str}; K_p_min=0 K_p_max={kp_max:.4f}") + print(f"{'─'*60}") + + omega_star = np.array([0.0, 0.0, args.omega_star]) + Kp_list = [0.05, 0.10, 0.20, 0.50] + + results = {} + for Kp in Kp_list: + env_wind = Env(disturbance_torque_std=wind_std, seed=args.seed) + r = run_trials( + env=env_wind, omega_star_vec=omega_star, Kp=Kp, + n_trials=args.n_trials, n_steps=args.n_steps, + conv_thr=args.conv_thr, seed=args.seed, + ic_mode="hold", hold_perturb=args.hold_perturb, + ) + print_trial_result(r, args.n_steps, env_wind.dt, args.conv_thr, + show_dist=True) + env_wind.close() + results[Kp] = r + + print(f"\n Final-error vs predicted SS offset (|d|/Kp):") + print(f" {'Kp':>8} {'final err (mean)':>18} {'predicted SS':>14} {'ratio':>8}") + for Kp, r in results.items(): + dn = r['disturbance_norms'].mean() + pred = dn / Kp + fe = r['final_err'].mean() + ratio = fe / pred if pred > 0 else float('nan') + print(f" {Kp:8.4f} {fe:18.4f} {pred:14.4f} {ratio:8.3f}") + + return results + + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── + +def main(): + args = get_args() + + print("=" * 62) + print("Stage E v2 — e₁/e₂/e₃ hold, no-wind + wind robustness") + print("=" * 62) + + # Default env (no wind, cfg0) + env_clean = Env(disturbance_torque_std=0.0, seed=args.seed) + iinfo = env_clean.get_inertia_info() + I1, I2, I3 = iinfo['I1'], iinfo['I2'], iinfo['I3'] + + print(f"\nRacket (cfg0):") + 
print(f" I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²") + print(f" I2/I1={I2/I1:.2f} asym=(I3-I1)/I2={(I3-I1)/I2:.2f}") + print(f"\nTrials: {args.n_trials} × {args.n_steps} steps " + f"(dt={env_clean.dt:.2f}s → {args.n_steps*env_clean.dt:.1f}s per trial)") + print(f"Convergence threshold: ‖ω-ω*‖ < {args.conv_thr:.2f} rad/s") + + # Stability table + print(f"\n── Free-body stability (ω* = {args.omega_star:.4f} rad/s, cfg0) ──") + print_stability_table(I1, I2, I3, args.omega_star, env_clean.dt) + + # Run all five scenarios + resA = scenario_A(args, I1, I2, I3, env_clean) + resB = scenario_B(args, I1, I2, I3, env_clean) + resC = scenario_C(args, I1, I2, I3, args.wind_std) + resD = scenario_D_e1_wind(args, I1, I2, I3, args.wind_std) + resE = scenario_E_e3_wind(args, I1, I2, I3, args.wind_std) + + # ── Final verdict ────────────────────────────────────────────────────── + A_pass = resA['pct_converged'] >= 95 + + B_r010 = resB.get(0.100, resB[max(resB)]) + B_pass = B_r010['pct_converged'] >= 95 and B_r010['late_max'].mean() < 0.5 + + def best_stable_kp(res): + """Return (best_kp, result) — highest convergence among stable Kps.""" + best = max( + (k for k in res if res[k]['late_max'].mean() < 0.5), + key=lambda k: res[k]['pct_converged'], + default=None, + ) + return best, (res[best] if best is not None else res[min(res)]) + + C_best_kp, C_r_best = best_stable_kp(resC) + C_pass = C_r_best['pct_converged'] >= 95 + + D_best_kp, D_r_best = best_stable_kp(resD) + D_pass = D_r_best['pct_converged'] >= 95 + + E_best_kp, E_r_best = best_stable_kp(resE) + E_pass = E_r_best['pct_converged'] >= 95 + + kp_max = kp_max_zoh(I1, env_clean.dt) + print(f"\n{'='*62}") + print(f"Stage E summary (Kp=0.10 unless noted; ZOH bound={kp_max:.3f})") + print(f" A — e₁ redirect (no wind): {'PASS ✓' if A_pass else 'FAIL ✗'}") + print(f" B — e₂ hold (no wind): {'PASS ✓' if B_pass else 'FAIL ✗'}") + print(f" C — e₂ hold (wind {args.wind_std:.2f} N·m): " + f"{'PASS ✓' if C_pass else 'FAIL ✗'}" + + (f" 
[best Kp={C_best_kp:.3f}]" if C_best_kp else "")) + print(f" D — e₁ hold (wind {args.wind_std:.2f} N·m): " + f"{'PASS ✓' if D_pass else 'FAIL ✗'}" + + (f" [best Kp={D_best_kp:.3f}]" if D_best_kp else "")) + print(f" E — e₃ hold (wind {args.wind_std:.2f} N·m): " + f"{'PASS ✓' if E_pass else 'FAIL ✗'}" + + (f" [best Kp={E_best_kp:.3f}]" if E_best_kp else "")) + print(f"{'='*62}") + + env_clean.close() + + +if __name__ == '__main__': + main() diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageF.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageF.py new file mode 100644 index 0000000..6422b01 --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/simulate_stageF.py @@ -0,0 +1,481 @@ +"""Stage F simulation: full-state (R*, ω*) stabilisation on SO(3) × ℝ³. + +Stage D/E only controlled angular velocity ω → ω*. Stage F adds an orientation +target R* ∈ SO(3), using the geodesic attitude error e_R = vee(logm(R*ᵀ R)). + +Control law: u = −K_R · e_R − K_p · e_ω + +Four scenarios: + A. Rest at identity: R* = I₃, ω* = 0. K_R scan. + B. Rest at 45° pitch: R* = Ry(π/4), ω* = 0. K_R = 0.10. + C. Sanity check (K_R=0 → Stage D): + R* = I₃, ω* = (0,0,2π). Must match Stage D. + D. Rest at identity with wind (σ = 0.05 N·m): + R* = I₃, ω* = 0. K_R = 0.10. + +All scenarios start from Dzhanibekov tumbling (axis=1, e₂ unstable) with a random +initial orientation R₀ ∈ SO(3). 
Convergence requires simultaneously: + ‖e_R‖ < theta_thr [rad] (default 0.10 rad ≈ 5.7°) + ‖e_ω‖ < omega_thr [rad/s] (default 0.50 rad/s) + +Usage: + /Users/katesur/Projects/LieSPHGP/venv/bin/python3 -u simulate_stageF.py + +Options: + --omega_star Spin rate for Scenario C [rad/s] (default 2π) + --n_trials Trials per configuration (default 20) + --n_steps Steps per trial (dt=0.05s) (default 300 = 15s) + --hold_perturb Not used for axis starts; kept for compatibility + --wind_std Disturbance std for Scenario D [N·m] (default 0.05) + --theta_thr Attitude convergence threshold [rad] (default 0.10) + --omega_thr Velocity convergence threshold [rad/s] (default 0.50) + --seed RNG seed (default 0) +""" +import argparse +import os +import sys +import numpy as np + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'envs')) +sys.path.insert(0, THIS_FILE_DIR) + +from envs.tennis_racket_3d import tennis_racket_3d as Env +from controller_stageF import GeometricAttitudeController, expm_SO3, hat + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── + +def get_args(): + p = argparse.ArgumentParser() + p.add_argument('--omega_star', type=float, default=2 * np.pi) + p.add_argument('--n_trials', type=int, default=20) + p.add_argument('--n_steps', type=int, default=300) + p.add_argument('--wind_std', type=float, default=0.05) + p.add_argument('--theta_thr', type=float, default=0.10, + help='Attitude convergence threshold ‖e_R‖ [rad]') + p.add_argument('--omega_thr', type=float, default=0.50, + help='Velocity convergence threshold ‖e_ω‖ [rad/s]') + p.add_argument('--seed', type=int, default=0) + return p.parse_args() + + +# 
───────────────────────────────────────────────────────────────────────────── +# Linearised stability helpers +# ───────────────────────────────────────────────────────────────────────────── + +def print_linearised_stability(I1, I2, I3, K_R, K_p, dt): + """Print ωn, ζ, mode type for each principal axis. + + The linearised closed-loop per axis i is an independent 2nd-order system: + [ė_R_i] = [0 1 ] [e_R_i] + [ė_ω_i] [-K_R/Iᵢ -K_p/Iᵢ] [e_ω_i] + + Natural frequency: ωn = √(K_R / Iᵢ) + Damping ratio: ζ = K_p / (2 √(K_R · Iᵢ)) + Critical K_p: K_p_cd = 2 √(K_R · Iᵢ) + """ + names = ["e₁ (I₁, smallest)", "e₂ (I₂, middle)", "e₃ (I₃, largest)"] + I_vals = [I1, I2, I3] + kp_max_zoh = 2.0 * I1 / dt + print(f" {'Axis':<22} {'ωn [rad/s]':>12} {'ζ':>8} {'mode':>12}") + print(f" {'-'*22} {'-'*12} {'-'*8} {'-'*12}") + for name, I in zip(names, I_vals): + if K_R < 1e-15: + print(f" {name:<22} {'—':>12} {'—':>8} {'K_R=0 (Stage D)':>12}") + continue + wn = np.sqrt(K_R / I) + zet = K_p / (2.0 * np.sqrt(K_R * I)) + mode = "overdamped" if zet > 1.0 else ("critically" if abs(zet - 1) < 0.05 + else "oscillatory") + print(f" {name:<22} {wn:12.3f} {zet:8.3f} {mode:>12}") + kr_max_zoh = K_p / dt # 2nd-order ZOH bound for K_R: K_R < K_p/dt + print(f" ZOH K_p_max = 2·I₁/dt = {kp_max_zoh:.4f} [current K_p={K_p:.4f} — " + + ("OK ✓" if K_p < kp_max_zoh else "EXCEEDS BOUND ✗") + "]") + print(f" ZOH K_R_max = K_p/dt = {kr_max_zoh:.4f} [current K_R={K_R:.4f} — " + + ("OK ✓" if K_R < kr_max_zoh else "EXCEEDS BOUND ✗") + "]") + + +# ───────────────────────────────────────────────────────────────────────────── +# Core simulation +# ───────────────────────────────────────────────────────────────────────────── + +def run_trials( + env: Env, + ctrl: GeometricAttitudeController, + n_trials: int, + n_steps: int, + theta_thr: float, + omega_thr: float, + seed: int, + ic_axis: int = 1, +) -> dict: + """Run n_trials closed-loop episodes. Returns convergence statistics. 
+ + Convergence criterion: ‖e_R‖ < theta_thr AND ‖e_ω‖ < omega_thr + simultaneously. + """ + dt = env.dt + + err_R_traces = np.zeros((n_trials, n_steps + 1), dtype=np.float64) + err_w_traces = np.zeros((n_trials, n_steps + 1), dtype=np.float64) + disturbance_norms = np.zeros(n_trials, dtype=np.float64) + conv_steps = [] + + for trial in range(n_trials): + obs, _ = env.reset(seed=seed + trial, options={"axis": ic_axis}) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + disturbance_norms[trial] = float(np.linalg.norm(env._disturbance_torque)) + + eR, ew = ctrl.errors(R_flat, omega) + err_R_traces[trial, 0] = eR + err_w_traces[trial, 0] = ew + + for step in range(1, n_steps + 1): + u = ctrl(R_flat, omega) + obs, _, _, _, _ = env.step(u) + R_flat = obs[:9].astype(np.float64) + omega = obs[9:12].astype(np.float64) + eR, ew = ctrl.errors(R_flat, omega) + err_R_traces[trial, step] = eR + err_w_traces[trial, step] = ew + + # First step both errors simultaneously below threshold + conv_mask = ((err_R_traces[trial] < theta_thr) + & (err_w_traces[trial] < omega_thr)) + below = np.where(conv_mask)[0] + conv_steps.append(int(below[0]) if len(below) else n_steps) + + conv_steps = np.array(conv_steps, dtype=float) + + return { + 'err_R_traces': err_R_traces, # (n_trials, n_steps+1) [rad] + 'err_w_traces': err_w_traces, # (n_trials, n_steps+1) [rad/s] + 't_grid': np.arange(n_steps + 1) * dt, + 'conv_steps': conv_steps, + 'conv_times': conv_steps * dt, # [s] + 'final_eR': err_R_traces[:, -1], + 'final_ew': err_w_traces[:, -1], + 'late_eR': err_R_traces[:, n_steps // 2:].max(axis=1), + 'late_ew': err_w_traces[:, n_steps // 2:].max(axis=1), + 'pct_converged': 100.0 * np.mean(conv_steps < n_steps), + 'disturbance_norms': disturbance_norms, + } + + +def print_trial_result(r: dict, n_steps: int, dt: float, + theta_thr: float, omega_thr: float, + K_R: float, K_p: float, + show_dist: bool = False): + ct = r['conv_times'] + feR = r['final_eR'] + few = 
r['final_ew'] + lmR = r['late_eR'] + lmw = r['late_ew'] + pct = r['pct_converged'] + eR0 = r['err_R_traces'][:, 0] + ew0 = r['err_w_traces'][:, 0] + + # Distinguish truly unstable (errors grow) from bounded SS offset. + # Compare 2nd-half mean vs 1st-step value — growing if 2nd-half > initial. + truly_unstable = (lmR.mean() > r['err_R_traces'][:, 0].mean() * 1.5 + or lmw.mean() > r['err_w_traces'][:, 0].mean() * 1.5) + has_large_ss = lmR.mean() > theta_thr * 5 or lmw.mean() > omega_thr * 5 + label = ("UNSTABLE" if truly_unstable + else "BOUNDED_SS" if has_large_ss + else "CONVERGED" if pct >= 95 + else "PARTIAL") + + print(f"\n K_R={K_R:.4f} K_p={K_p:.4f} [{label}]") + print(f" Initial ‖e_R‖: {eR0.mean():.3f} ± {eR0.std():.3f} rad") + print(f" Initial ‖e_ω‖: {ew0.mean():.3f} ± {ew0.std():.3f} rad/s") + print(f" Final ‖e_R‖: {feR.mean():.4f} ± {feR.std():.4f} rad") + print(f" Final ‖e_ω‖: {few.mean():.4f} ± {few.std():.4f} rad/s") + print(f" Late max ‖e_R‖: {lmR.mean():.4f} ± {lmR.std():.4f} rad (2nd half)") + print(f" Late max ‖e_ω‖: {lmw.mean():.4f} ± {lmw.std():.4f} rad/s (2nd half)") + if pct > 0: + print(f" Conv (‖e_R‖<{theta_thr:.2f} AND ‖e_ω‖<{omega_thr:.2f}): " + f"{ct.mean():.3f} ± {ct.std():.3f} s " + f"(min={ct.min():.2f} max={ct.max():.2f})") + print(f" Converged: {pct:.0f}% of {len(feR)} trials within {n_steps*dt:.1f}s") + if show_dist: + dn = r['disturbance_norms'] + print(f" |d|: {dn.mean():.4f} ± {dn.std():.4f} N·m " + f"(SS ‖e_ω‖ ≈ {dn.mean()/K_p:.3f} rad/s; " + f"SS ‖e_R‖ ≈ {dn.mean()/K_R:.3f} rad)") + + +def print_error_profile(r: dict, env_dt: float, n_steps: int, n_trials: int): + """Print mean attitude and velocity error at selected timesteps.""" + print(f"\n Error profile (avg over {n_trials} trials):") + print(f" {'t[s]':>6} {'‖e_R‖ mean':>12} {'‖e_R‖ max':>12} " + f"{'‖e_ω‖ mean':>12} {'‖e_ω‖ max':>12}") + steps = list(range(0, min(21, n_steps + 1), 2)) + if n_steps not in steps: + steps.append(n_steps) + for s in steps: + t = s * env_dt + 
eRm = r['err_R_traces'][:, s].mean() + eRx = r['err_R_traces'][:, s].max() + ewm = r['err_w_traces'][:, s].mean() + ewx = r['err_w_traces'][:, s].max() + print(f" {t:6.2f} {eRm:12.6f} {eRx:12.6f} {ewm:12.6f} {ewx:12.6f}") + + +# ───────────────────────────────────────────────────────────────────────────── +# Scenario runners +# ───────────────────────────────────────────────────────────────────────────── + +def scenario_A(args, I1, I2, I3, env_clean): + """Rest at identity: R* = I₃, ω* = 0. K_R scan.""" + print(f"\n{'─'*62}") + print(f"Scenario A: Rest at identity R* = I₃, ω* = 0") + print(f" IC: Dzhanibekov tumbling (axis=1, random R₀)") + print(f" K_p = 0.10 fixed; scanning K_R") + print(f"{'─'*62}") + + R_star = np.eye(3) + omega_star = np.zeros(3) + K_p = 0.10 + Kp_max = 2.0 * I1 / env_clean.dt + K_R_list = [0.01, 0.05, 0.10, 0.20, 0.50] + + results = {} + for K_R in K_R_list: + print(f"\n ── Stability analysis (K_R={K_R:.4f}, K_p={K_p:.4f}) ──") + print_linearised_stability(I1, I2, I3, K_R, K_p, env_clean.dt) + + ctrl = GeometricAttitudeController(R_star, omega_star, K_R=K_R, K_p=K_p) + r = run_trials( + env=env_clean, ctrl=ctrl, + n_trials=args.n_trials, n_steps=args.n_steps, + theta_thr=args.theta_thr, omega_thr=args.omega_thr, + seed=args.seed, + ) + print_trial_result(r, args.n_steps, env_clean.dt, + args.theta_thr, args.omega_thr, K_R, K_p) + results[K_R] = r + + # Pick best: fastest convergence among fully converged + fully_conv = [k for k, r in results.items() if r['pct_converged'] >= 95] + if fully_conv: + best_KR = min(fully_conv, key=lambda k: results[k]['conv_times'].mean()) + else: + best_KR = K_R_list[-1] + + print(f"\n ── Error profile for best K_R = {best_KR:.4f} ──") + print_error_profile(results[best_KR], env_clean.dt, args.n_steps, args.n_trials) + + return results, best_KR + + +def scenario_B(args, I1, I2, I3, env_clean): + """Rest at 45° pitch: R* = Ry(π/4), ω* = 0. 
def scenario_C_sanity(args, I1, I2, I3, env_clean):
    """Scenario C: velocity-only sanity run with the attitude gain set to zero.

    Builds a ``GeometricAttitudeController`` with K_R = 0 and K_p = 0.10 aimed
    at R* = I₃ and ω* = (0, 0, ``args.omega_star``), runs ``run_trials`` on
    ``env_clean`` and prints a summary of the velocity convergence only. The
    banner states the expectation that this reproduces the Stage D result.

    Args:
        args: parsed CLI namespace; reads ``omega_star``, ``n_trials``,
            ``n_steps``, ``omega_thr`` and ``seed``.
        I1, I2, I3: principal moments of inertia; not used by this scenario.
        env_clean: environment instance handed to ``run_trials``.

    Returns:
        The result dict produced by ``run_trials``.
    """
    attitude_gain = 0.0
    rate_gain = 0.10
    attitude_target = np.eye(3)
    spin_target = np.array([0.0, 0.0, args.omega_star])

    rule = '─' * 62
    print(f"\n{rule}")
    print(f"Scenario C (sanity): K_R=0 → Stage D "
          f"[ω*=(0,0,{args.omega_star:.2f}), K_p={rate_gain}]")
    print(" Expected: ~100% convergence in ~0.3s (matches Stage D result)")
    print(rule)

    controller = GeometricAttitudeController(
        attitude_target, spin_target, K_R=attitude_gain, K_p=rate_gain)
    trial_kwargs = dict(
        env=env_clean,
        ctrl=controller,
        n_trials=args.n_trials,
        n_steps=args.n_steps,
        # With K_R = 0 there is no attitude criterion; π keeps it from blocking.
        theta_thr=np.pi,
        omega_thr=args.omega_thr,
        seed=args.seed,
    )
    outcome = run_trials(**trial_kwargs)

    # Only the velocity convergence is reported for this sanity check.
    conv_times = outcome['conv_times']
    final_rate_err = outcome['final_ew']
    pct_converged = outcome['pct_converged']
    verdict = 'PASS ✓' if pct_converged >= 95 else 'FAIL ✗'

    print(f"\n Conv time (‖e_ω‖<{args.omega_thr:.2f}): "
          f"{conv_times.mean():.3f} ± {conv_times.std():.3f} s")
    print(f" Final ‖e_ω‖: {final_rate_err.mean():.4f} ± "
          f"{final_rate_err.std():.4f} rad/s")
    print(f" Converged: {pct_converged:.0f}% of {args.n_trials} trials "
          f"({verdict} matches Stage D)")
    return outcome
def main():
    """Run Stage F scenarios A–D and print a PASS/FAIL summary.

    Parses the CLI arguments, builds one disturbance-free environment, prints
    the racket inertia and trial settings, then runs ``scenario_A``,
    ``scenario_B``, ``scenario_C_sanity`` and ``scenario_D_wind`` in that
    order. A scenario passes when at least 95% of its trials converged; for
    A and D the verdict is taken from the best K_R reported by the scenario.

    The clean environment is closed in a ``finally`` clause so it is released
    even when a scenario raises.
    """
    args = get_args()

    print("=" * 62)
    print("Stage F — Geometric Attitude Control on SO(3) × ℝ³")
    print("=" * 62)

    def _verdict(passed):
        return 'PASS ✓' if passed else 'FAIL ✗'

    env_clean = Env(disturbance_torque_std=0.0, seed=args.seed)
    try:
        iinfo = env_clean.get_inertia_info()
        I1, I2, I3 = iinfo['I1'], iinfo['I2'], iinfo['I3']

        print(f"\nRacket (cfg0):")
        print(f" I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²")
        print(f" K_p_max (ZOH) = 2·I₁/dt = {2*I1/env_clean.dt:.4f} N·m·s/rad")
        print(f"\nTrials: {args.n_trials} × {args.n_steps} steps "
              f"(dt={env_clean.dt:.2f}s → {args.n_steps*env_clean.dt:.1f}s per trial)")
        print(f"Convergence: ‖e_R‖ < {args.theta_thr:.2f} rad AND "
              f"‖e_ω‖ < {args.omega_thr:.2f} rad/s (simultaneously)")

        resA, best_KR_A = scenario_A(args, I1, I2, I3, env_clean)
        resB = scenario_B(args, I1, I2, I3, env_clean)
        resC = scenario_C_sanity(args, I1, I2, I3, env_clean)
        resD, best_KR_D = scenario_D_wind(args, I1, I2, I3, args.wind_std)

        # ── Final verdict ──────────────────────────────────────────────────
        A_pass = resA[best_KR_A]['pct_converged'] >= 95
        B_pass = resB['pct_converged'] >= 95
        C_pass = resC['pct_converged'] >= 95

        # Scenario D: judged on the best K_R; when the scenario reported none,
        # fall back to the smallest K_R that was scanned.
        D_r_best = resD[best_KR_D] if best_KR_D is not None else resD[min(resD)]
        D_pass = D_r_best['pct_converged'] >= 95

        # Explicit None test: a numeric K_R must not be dropped by truthiness.
        best_D_note = (f" [best K_R={best_KR_D:.2f}]"
                       if best_KR_D is not None else "")

        print(f"\n{'='*62}")
        print(f"Stage F summary")
        print(f" A — rest at I₃ (no wind): {_verdict(A_pass)}"
              f" [best K_R={best_KR_A:.3f}]")
        print(f" B — rest at Ry(π/4)(no wind): {_verdict(B_pass)}")
        print(f" C — Stage D sanity (K_R=0): {_verdict(C_pass)}")
        print(f" D — rest at I₃ (wind {args.wind_std:.2f} N·m): "
              f"{_verdict(D_pass)}"
              + best_D_note)
        print(f"{'='*62}")
    finally:
        env_clean.close()
+ - FixedInertia replaces FixedInverseMass: returns diag(1/I1,1/I2,1/I3). + - pretrain_M_net target is diag(1/I1,1/I2,1/I3) read from the dataset. + - subnet diagnostics use subnet_physics_mse_tennis (V_tgt=0, Dw_tgt=0). + - Pendulum-specific args removed: friction_coeff, wind, external_force, random_u. + - network.py and loss_utils.py are imported from the shared pendulum ph_nn_ode_v2 + directory (identical architecture, no copy needed). + +Stage A assumptions (torque-free, no friction): + - disturbance_torque_std = 0 → Dw_tgt = 0 in diagnostics + - V_net should converge to a near-constant (V_tgt = 0 for free body) + - M_net should converge to diag(1/I1, 1/I2, 1/I3) + - g_net should converge to I₃ + +Benchmarks to declare Stage A successful: + 1. eval_M_loss < 1e-4 by end of training + 2. test_geo_loss < 0.01 rad² on the windowed eval + 3. eval_V_loss is small relative to eval_M_loss + 4. det(R̂) ≈ 1 throughout (guaranteed by architecture, verify numerically) +""" +import torch, argparse +import numpy as np +import os, sys +import time +import pickle + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) + +# Shared network.py and loss_utils.py live in the pendulum model directory. +# The architecture is identical; we reuse without copying. 
class FixedInertia(torch.nn.Module):
    """Parameter-free stand-in for M_net returning diag(1/I1, 1/I2, 1/I3).

    M_net outputs M⁻¹ (the inverse inertia tensor); for the tennis racket that
    is the constant diagonal matrix diag(1/I1, 1/I2, 1/I3). Selecting --fix_M
    pins M_net to this module so the optimizer only has V_net, Dw_net and
    g_net left to fit.
    """

    def __init__(self, I1: float, I2: float, I3: float):
        """Store the float32 inverse-inertia matrix as a buffer named 'M_inv'."""
        super().__init__()
        reciprocals = [1.0 / moment for moment in (I1, I2, I3)]
        inv_inertia = torch.diag(torch.tensor(reciprocals, dtype=torch.float32))
        # A buffer, not a Parameter: it follows .to(device) but is never trained.
        self.register_buffer('M_inv', inv_inertia)

    def forward(self, q):
        """Return M_inv broadcast to shape (q.shape[0], 3, 3); q's values are ignored."""
        batch = q.shape[0]
        return self.M_inv[None, :, :].expand(batch, 3, 3)
+ parser.add_argument('--config_idx', type=int, default=0, + help='index into data[inertia_info] to select (0-3 for the default 4-config dataset)') + + # M_net pretraining: anchors M_net near diag(1/I1,1/I2,1/I3) before joint training. + parser.add_argument('--pretrain_M_steps', type=int, default=200) + parser.add_argument('--pretrain_M_lr', type=float, default=1e-3) + parser.add_argument('--pretrain_M_print_every', type=int, default=20) + + # Physics-informed auxiliary losses (off by default). + parser.add_argument('--lambda_power', type=float, default=0.0, + help='weight for power-balance loss; 0 disables') + parser.add_argument('--lambda_V', type=float, default=0.0, + help='weight for V back-solving loss; 0 disables') + parser.add_argument('--lambda_B', type=float, default=0.0, + help='weight for B (input coupling) back-solving loss; 0 disables') + parser.add_argument('--lambda_D', type=float, default=0.0, + help='weight for D (dissipation) back-solving loss; 0 disables') + + # Pin M_net to the analytic ground-truth diag(1/I1,1/I2,1/I3). 
def pretrain_M_net(model, q_samples, n_steps, lr, print_every, I1, I2, I3):
    """Pretrain M_net to output the true inverse inertia diag(1/I1, 1/I2, 1/I3).

    For a free rigid body the inertia tensor is constant (independent of R), so
    M_net needs to learn a constant diagonal PSD matrix. Pretraining anchors it
    near the correct value before joint training begins.

    Args:
        model: network exposing an ``M_net`` submodule (possibly wrapped by
            torch.compile; it is unwrapped with ``_inner``).
        q_samples: batch of M_net inputs, one sample per row along dim 0;
            detached before use. Its dtype and device are used for the target.
        n_steps: number of Adam steps; ``<= 0`` makes the call a no-op.
        lr: Adam learning rate for the pretraining optimizer.
        print_every: progress-print cadence in steps (values below 1 act as 1).
        I1, I2, I3: principal moments of inertia defining the target.

    Returns:
        None. ``M_net`` is updated in place and progress is printed.
    """
    if n_steps <= 0:
        return

    inner = _inner(model)

    # Adam raises on an empty parameter list (e.g. a buffer-only M_net), so
    # there is nothing to pretrain in that case.
    m_params = [p for p in inner.M_net.parameters() if p.requires_grad]
    if not m_params:
        print("\nPretraining M_net skipped: M_net has no trainable parameters")
        return
    # The mean over an empty batch is NaN and one Adam step on it would
    # overwrite every M_net weight with NaN.
    if q_samples.shape[0] == 0:
        print("\nPretraining M_net skipped: no q-samples supplied")
        return

    device = q_samples.device
    dtype = q_samples.dtype

    I_inv_diag = torch.tensor([1.0/I1, 1.0/I2, 1.0/I3], dtype=dtype, device=device)
    target_mat = torch.diag(I_inv_diag)
    target = target_mat.unsqueeze(0).expand(q_samples.shape[0], 3, 3)

    print(f"\nPretraining M_net for {n_steps} steps (lr={lr})")
    print(f" target: diag(1/I1,1/I2,1/I3) = "
          f"diag({1/I1:.4f}, {1/I2:.4f}, {1/I3:.4f})")
    print(f" {q_samples.shape[0]} q-samples drawn from training data")

    optim = torch.optim.Adam(m_params, lr=lr, weight_decay=1e-4)
    q_no_grad = q_samples.detach()
    cadence = max(1, print_every)

    # Drop any gradient left on M_net by earlier work so step 0 is clean.
    optim.zero_grad()

    initial_loss = None
    for step in range(n_steps):
        M_pred = inner.M_net(q_no_grad)
        loss = (M_pred - target).pow(2).mean()
        loss.backward()
        optim.step()
        # Zero after stepping so no gradient survives into joint training.
        optim.zero_grad()
        if initial_loss is None:
            initial_loss = loss.item()
        if step % cadence == 0 or step == n_steps - 1:
            print(f" pretrain step {step:>4d}: loss={loss.item():.3e}")

    with torch.no_grad():
        # Re-evaluate with the final weights: `loss` above was computed before
        # the last optimizer step and so does not describe the returned M_net.
        final_loss = (inner.M_net(q_no_grad) - target).pow(2).mean().item()
        M_check = inner.M_net(q_no_grad[:1])
        deviation = (M_check - target_mat).abs().max().item()
    print(f" pretrain done. initial={initial_loss:.3e} "
          f"final={final_loss:.3e} max|M(q₀)−target|={deviation:.3e}")
optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4) + + # ── Arrange data ────────────────────────────────────────────────────── + train_x, t_eval = arrange_data(cfg_x, data['t'], num_points=args.num_points) + test_x, _ = arrange_data(cfg_test, data['t'], num_points=args.num_points) + train_x_cat = np.concatenate(train_x, axis=1) + test_x_cat = np.concatenate(test_x, axis=1) + + train_x_cat = torch.tensor(train_x_cat, requires_grad=True, + dtype=float_type).to(device) + test_x_cat = torch.tensor(test_x_cat, requires_grad=True, + dtype=float_type).to(device) + t_eval = torch.tensor(t_eval, requires_grad=True, + dtype=float_type).to(device) + + # ── M_net pretraining ───────────────────────────────────────────────── + if args.pretrain_M_steps > 0 and not args.fix_M: + q_pretrain = train_x_cat.detach().reshape(-1, 15)[:, :9] + pretrain_M_net( + model=model, + q_samples=q_pretrain, + n_steps=args.pretrain_M_steps, + lr=args.pretrain_M_lr, + print_every=args.pretrain_M_print_every, + I1=I1, I2=I2, I3=I3, + ) + + split = [9, 3, 3] + + stats = { + 'train_loss': [], 'train_l2_loss': [], 'train_geo_loss': [], + 'train_power_loss': [], + 'train_V_cons_loss': [], 'train_B_cons_loss': [], 'train_D_cons_loss': [], + 'forward_time': [], 'backward_time': [], 'nfe': [], + # Eval (windowed) — recorded every eval_every steps + 'eval_step': [], + 'test_loss': [], 'test_l2_loss': [], 'test_geo_loss': [], + 'eval_M_loss': [], 'eval_V_loss': [], + 'eval_Dw_loss': [], 'eval_g_loss': [], + # Config metadata saved alongside training stats + 'config_idx': args.config_idx, + 'I1': I1, 'I2': I2, 'I3': I3, + } + + dt_train = (t_eval[1] - t_eval[0]).detach().item() + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + stats_path = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + + loss_buffer = [] + fwd_buffer = [] + bwd_buffer = [] + nfe_buffer = [] + + for step in range(args.total_steps + 1): + + # ── 
Training step ──────────────────────────────────────────────── + t = time.time() + train_x_hat = odeint(model, train_x_cat[0, :, :], t_eval, method=args.solver) + forward_time = time.time() - t + + target = train_x_cat[1:, :, :] + target_hat = train_x_hat[1:, :, :] + train_loss, train_l2_loss, train_geo_loss = rotmat_L2_geodesic_loss( + target, target_hat, split=split) + + if args.lambda_power > 0.0: + L_power = power_balance_loss(model, train_x_cat, dt_train) + else: + L_power = torch.zeros((), device=device, dtype=float_type) + + if args.lambda_V > 0.0 or args.lambda_B > 0.0 or args.lambda_D > 0.0: + L_V, L_B, L_D = consistency_subnet_losses(model, train_x_cat, dt_train) + else: + L_V = torch.zeros((), device=device, dtype=float_type) + L_B = torch.zeros((), device=device, dtype=float_type) + L_D = torch.zeros((), device=device, dtype=float_type) + + total_loss = (train_loss + + args.lambda_power * L_power + + args.lambda_V * L_V + + args.lambda_B * L_B + + args.lambda_D * L_D) + + t = time.time() + total_loss.backward() + optim.step() + optim.zero_grad() + backward_time = time.time() - t + + loss_buffer.append(torch.stack([ + total_loss.detach(), train_l2_loss.detach(), train_geo_loss.detach(), + L_power.detach(), L_V.detach(), L_B.detach(), L_D.detach() + ])) + fwd_buffer.append(forward_time) + bwd_buffer.append(backward_time) + nfe = getattr(model, 'nfe', + getattr(getattr(model, '_orig_mod', model), 'nfe', 0)) + nfe_buffer.append(nfe) + + # ── Windowed eval + diagnostics + checkpoint ───────────────────── + if step % args.eval_every == 0: + with torch.no_grad(): + test_x_hat = odeint( + model, test_x_cat[0, :, :], t_eval, method=args.solver) + tgt = test_x_cat[1:, :, :] + tgt_hat = test_x_hat[1:, :, :] + test_loss, test_l2_loss, test_geo_loss = rotmat_L2_geodesic_loss( + tgt, tgt_hat, split=split) + subnet = subnet_physics_mse_tennis( + model, test_x_hat, I1=I1, I2=I2, I3=I3) + + if loss_buffer: + drained = torch.stack(loss_buffer, dim=0).cpu().numpy() # (N, 
7) + stats['train_loss'].extend(drained[:, 0].tolist()) + stats['train_l2_loss'].extend(drained[:, 1].tolist()) + stats['train_geo_loss'].extend(drained[:, 2].tolist()) + stats['train_power_loss'].extend(drained[:, 3].tolist()) + stats['train_V_cons_loss'].extend(drained[:, 4].tolist()) + stats['train_B_cons_loss'].extend(drained[:, 5].tolist()) + stats['train_D_cons_loss'].extend(drained[:, 6].tolist()) + stats['forward_time'].extend(fwd_buffer) + stats['backward_time'].extend(bwd_buffer) + stats['nfe'].extend(nfe_buffer) + loss_buffer = []; fwd_buffer = []; bwd_buffer = []; nfe_buffer = [] + + test_pack = torch.stack([ + test_loss.detach(), test_l2_loss.detach(), test_geo_loss.detach() + ]).cpu().numpy() + train_total = stats['train_loss'][-1] + train_l2 = stats['train_l2_loss'][-1] + train_geo = stats['train_geo_loss'][-1] + + stats['eval_step'].append(step) + stats['test_loss'].append(float(test_pack[0])) + stats['test_l2_loss'].append(float(test_pack[1])) + stats['test_geo_loss'].append(float(test_pack[2])) + stats['eval_M_loss'].append(subnet['M_loss']) + stats['eval_V_loss'].append(subnet['V_loss']) + stats['eval_Dw_loss'].append(subnet['Dw_loss']) + stats['eval_g_loss'].append(subnet['g_loss']) + + train_pow = stats['train_power_loss'][-1] + train_LV = stats['train_V_cons_loss'][-1] + train_LB = stats['train_B_cons_loss'][-1] + train_LD = stats['train_D_cons_loss'][-1] + + print(f"[step {step:>6d}]") + print(f" train: total={train_total:.4e} " + f"L2={train_l2:.4e} geo={train_geo:.4e} " + f"power={train_pow:.4e}") + print(f" cons : L_V={train_LV:.4e} " + f"L_B={train_LB:.4e} L_D={train_LD:.4e}") + print(f" test : total={test_pack[0]:.4e} " + f"L2={test_pack[1]:.4e} geo={test_pack[2]:.4e}") + print(f" subnet MSE M={subnet['M_loss']:.3e} " + f"V={subnet['V_loss']:.3e} Dw={subnet['Dw_loss']:.3e} " + f"g={subnet['g_loss']:.3e} | nfe={nfe}") + + ckpt = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-{step}.tar') + 
torch.save(_state_dict(model), ckpt) + to_pickle(stats, stats_path) + + # ── Final per-trajectory eval ───────────────────────────────────────── + cfg_x_full = torch.tensor(cfg_x, requires_grad=True, + dtype=float_type).to(device) + cfg_test_full = torch.tensor(cfg_test, requires_grad=True, + dtype=float_type).to(device) + t_full = torch.tensor(data['t'], requires_grad=True, + dtype=float_type).to(device) + + train_loss_l, test_loss_l = [], [] + train_l2_l, test_l2_l = [], [] + train_geo_l, test_geo_l = [], [] + train_data_hat, test_data_hat = [], [] + + for i in range(cfg_x_full.shape[0]): # single iteration for Stage A + train_x_hat = odeint( + model, cfg_x_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_x_full[i, :, :, :], train_x_hat, split=split) + train_loss_l.append(total_loss) + train_l2_l.append(l2_loss) + train_geo_l.append(geo_loss) + train_data_hat.append(train_x_hat.detach().cpu().numpy()) + + test_x_hat = odeint( + model, cfg_test_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_test_full[i, :, :, :], test_x_hat, split=split) + test_loss_l.append(total_loss) + test_l2_l.append(l2_loss) + test_geo_l.append(geo_loss) + test_data_hat.append(test_x_hat.detach().cpu().numpy()) + + def _per_traj(loss_list): + return torch.sum(torch.cat(loss_list, dim=1), dim=0) + + train_loss_pt = _per_traj(train_loss_l) + test_loss_pt = _per_traj(test_loss_l) + train_l2_pt = _per_traj(train_l2_l) + test_l2_pt = _per_traj(test_l2_l) + train_geo_pt = _per_traj(train_geo_l) + test_geo_pt = _per_traj(test_geo_l) + + print('Final trajectory train loss {:.4e} +/- {:.4e}\n' + 'Final trajectory test loss {:.4e} +/- {:.4e}'.format( + train_loss_pt.mean().item(), train_loss_pt.std().item(), + test_loss_pt.mean().item(), test_loss_pt.std().item())) + print('Final trajectory train geo {:.4e} +/- {:.4e}\n' + 'Final trajectory test geo {:.4e} +/- 
{:.4e}'.format( + train_geo_pt.mean().item(), train_geo_pt.std().item(), + test_geo_pt.mean().item(), test_geo_pt.std().item())) + + stats['traj_train_loss'] = train_loss_pt.detach().cpu().numpy() + stats['traj_test_loss'] = test_loss_pt.detach().cpu().numpy() + stats['train_x'] = cfg_x_full.detach().cpu().numpy() + stats['test_x'] = cfg_test_full.detach().cpu().numpy() + stats['train_x_hat'] = np.array(train_data_hat) + stats['test_x_hat'] = np.array(test_data_hat) + stats['t_eval'] = t_full.detach().cpu().numpy() + return model, stats + + +# ───────────────────────────────────────────────────────────────────────────── +# Entry point +# ───────────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + args = get_args() + model, stats = train(args) + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + final_ckpt = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p.tar') + torch.save(_state_dict(model), final_ckpt) + final_stats = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + print("Saved final stats:", final_stats) + to_pickle(stats, final_stats) diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_e2e.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_e2e.py new file mode 100644 index 0000000..8e59ef1 --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_e2e.py @@ -0,0 +1,545 @@ +"""End-to-end Stage A training: M_net is learned, not fixed. + +Copied from train.py and extended to fix the Hamiltonian identifiability +degeneracy that arises when M_net is trained jointly with V_net. + +Problem: without constraints, the optimizer can fit trajectories equally well +with (M_wrong, V_compensating) as with (M_true, V=0). The degeneracy +manifests as M_loss diverging while test_geo_loss converges — M_net drifts +away from ground truth while V_net absorbs the residual. + +Fixes applied here vs. train.py: + 1. 
--pretrain_M_lr defaults to 0.1 (was 1e-3 — too small for O(100) targets) + 2. --pretrain_M_steps defaults to 1000 (was 200) + 3. --lambda_V_zero: penalises mean(V_net(q)²) directly. For a free rigid + body V=0 is the correct physics; suppressing V forces M_net to account + for all the kinetic energy and breaks the (M,V) trade-off. + +Stage A end-to-end pass criteria (same as train.py): + 1. eval_M_loss < 1e-4 + 2. test_geo_loss < 0.01 rad² + 3. eval_V_loss small + 4. det(R̂) ≈ 1 (guaranteed by architecture) +""" +import torch, argparse +import numpy as np +import os, sys +import time +import pickle + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) + +# Shared network.py and loss_utils.py live in the pendulum model directory. +# The architecture is identical; we reuse without copying. +PENDULUM_ODE_DIR = os.path.join( + PROJECT_ROOT, 'src/models/3D_SO3_Windy_Pendulum/ph_nn_ode_v2') + +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src/utils')) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'datasets')) +sys.path.insert(0, PENDULUM_ODE_DIR) + +from torchdiffeq import odeint + +from ode_utils import to_pickle +from subnet_diagnostics_tennis import subnet_physics_mse_tennis +from tennis_racket_3d_datagen import arrange_data +from network import DissipativeSO3HamNODE +from loss_utils import ( + rotmat_L2_geodesic_loss_safe as rotmat_L2_geodesic_loss, + traj_rotmat_L2_geodesic_loss_safe as traj_rotmat_L2_geodesic_loss, + power_balance_loss, + consistency_subnet_losses, +) + + +DEFAULT_SAVE_DIR = os.path.join(THIS_FILE_DIR, 'data', 'run_tr3d_e2e_fp32') +DEFAULT_DATA_PATH = os.path.join( + PROJECT_ROOT, + 'data/tennis_data/' + 'tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl' +) + + +# ───────────────────────────────────────────────────────────────────────────── +# Fixed inverse-inertia module (drop-in M_net for Stage A sanity checks) 
class FixedInertia(torch.nn.Module):
    """Constant M_net replacement: always yields diag(1/I1, 1/I2, 1/I3).

    M_net is expected to output M⁻¹, the inverse inertia tensor, which for the
    tennis racket is diag(1/I1, 1/I2, 1/I3). With --fix_M this module takes the
    place of the learned M_net, leaving V_net, Dw_net and g_net as the only
    parts the optimizer adjusts. The module has no learnable parameters.
    """

    def __init__(self, I1: float, I2: float, I3: float):
        """Register the float32 inverse-inertia matrix under the buffer name 'M_inv'."""
        super().__init__()
        inverse_moments = torch.tensor(
            [1.0 / I1, 1.0 / I2, 1.0 / I3], dtype=torch.float32)
        self.register_buffer('M_inv', inverse_moments.diag())

    def forward(self, q):
        """Return the stored matrix expanded to (q.shape[0], 3, 3); q's contents are unused."""
        n_rows = q.shape[0]
        # expand() prepends the batch dimension as a view; no data is copied.
        return self.M_inv.expand(n_rows, 3, 3)
parser.add_argument('--disturbance_torque_std', type=float, default=0.0) + parser.add_argument('--perturb_std', type=float, default=0.05) + + # Which geometry config in the dataset to train on (0-indexed). + parser.add_argument('--config_idx', type=int, default=0, + help='index into data[inertia_info] to select (0-3 for the default 4-config dataset)') + + # M_net pretraining: anchors M_net near diag(1/I1,1/I2,1/I3) before joint training. + # Defaults are higher than train.py: targets are O(100) so lr=0.1 and 1000 steps + # are needed to converge before joint training begins. + parser.add_argument('--pretrain_M_steps', type=int, default=1000) + parser.add_argument('--pretrain_M_lr', type=float, default=0.1) + parser.add_argument('--pretrain_M_print_every', type=int, default=100) + + # Physics-informed auxiliary losses (off by default). + parser.add_argument('--lambda_power', type=float, default=0.0, + help='weight for power-balance loss; 0 disables') + parser.add_argument('--lambda_V', type=float, default=0.0, + help='weight for V back-solving loss; 0 disables') + parser.add_argument('--lambda_B', type=float, default=0.0, + help='weight for B (input coupling) back-solving loss; 0 disables') + parser.add_argument('--lambda_D', type=float, default=0.0, + help='weight for D (dissipation) back-solving loss; 0 disables') + # V-suppression: penalises mean(V_net(q)²) to break the (M,V) Hamiltonian degeneracy. + # For a torque-free free body, V=0 is exact physics; this forces M_net to carry all + # kinetic energy and makes it uniquely identifiable from trajectory data. + parser.add_argument('--lambda_V_zero', type=float, default=1.0, + help='weight for V=0 suppression loss mean(V(q)²); 0 disables') + + # Pin M_net to the analytic ground-truth diag(1/I1,1/I2,1/I3). 
+ parser.add_argument('--fix_M', action='store_true', + help='use FixedInertia (non-learnable M_net) with true I from the data') + + return parser.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Utilities +# ───────────────────────────────────────────────────────────────────────────── + +def get_model_parm_nums(model): + return sum(p.nelement() for p in model.parameters()) + + +def _fmt_num(x): + """Format a number for filesystem paths: 0.01 → 0p01, -1.0 → n1.""" + s = f"{x:g}" + return s.replace('.', 'p').replace('-', 'n').replace('+', '') + + +def build_run_name(args): + """Build a run-folder name encoding training specs + a YYMMDD-HHMMSS stamp.""" + parts = [ + f"obs{_fmt_num(args.obs_noise_std)}", + f"dist{_fmt_num(args.disturbance_torque_std)}", + f"cfg{args.config_idx}", + f"lP{_fmt_num(args.lambda_power)}", + f"lV0{_fmt_num(args.lambda_V_zero)}", + f"lV{_fmt_num(args.lambda_V)}", + f"lB{_fmt_num(args.lambda_B)}", + f"lD{_fmt_num(args.lambda_D)}", + f"lr{_fmt_num(args.learn_rate)}", + f"s{args.total_steps}", + f"np{args.num_points}", + f"smp{args.samples}", + f"T{args.timesteps}", + f"{args.solver}", + f"seed{args.seed}", + ] + if args.fix_M: + parts.append('fixM') + stamp = time.strftime('%y%m%d-%H%M%S') + return '_'.join(parts) + '_' + stamp + + +def _state_dict(model): + return (model._orig_mod if hasattr(model, '_orig_mod') else model).state_dict() + + +def _inner(model): + """Unwrap a possibly torch.compile-wrapped model.""" + return model._orig_mod if hasattr(model, '_orig_mod') else model + + +# ───────────────────────────────────────────────────────────────────────────── +# M_net pretraining +# ───────────────────────────────────────────────────────────────────────────── + +def pretrain_M_net(model, q_samples, n_steps, lr, print_every, I1, I2, I3): + """Pretrain M_net to output the true inverse inertia diag(1/I1, 1/I2, 1/I3). 
+ + For a free rigid body the inertia tensor is constant (independent of R), so + M_net needs to learn a constant diagonal PSD matrix. Pretraining anchors it + near the correct value before joint training begins. + """ + if n_steps <= 0: + return + + inner = _inner(model) + device = q_samples.device + dtype = q_samples.dtype + + I_inv_diag = torch.tensor([1.0/I1, 1.0/I2, 1.0/I3], dtype=dtype, device=device) + target_mat = torch.diag(I_inv_diag) + target = target_mat.unsqueeze(0).expand(q_samples.shape[0], 3, 3) + + print(f"\nPretraining M_net for {n_steps} steps (lr={lr})") + print(f" target: diag(1/I1,1/I2,1/I3) = " + f"diag({1/I1:.4f}, {1/I2:.4f}, {1/I3:.4f})") + print(f" {q_samples.shape[0]} q-samples drawn from training data") + + optim = torch.optim.Adam(inner.M_net.parameters(), lr=lr, weight_decay=1e-4) + q_no_grad = q_samples.detach() + + initial_loss = None + for step in range(n_steps): + M_pred = inner.M_net(q_no_grad) + loss = (M_pred - target).pow(2).mean() + loss.backward() + optim.step() + optim.zero_grad() + if initial_loss is None: + initial_loss = loss.item() + if step % max(1, print_every) == 0 or step == n_steps - 1: + print(f" pretrain step {step:>4d}: loss={loss.item():.3e}") + + with torch.no_grad(): + M_check = inner.M_net(q_no_grad[:1]) + deviation = (M_check - target_mat).abs().max().item() + print(f" pretrain done. 
initial={initial_loss:.3e} " + f"final={loss.item():.3e} max|M(q₀)−target|={deviation:.3e}") + + +# ───────────────────────────────────────────────────────────────────────────── +# Training +# ───────────────────────────────────────────────────────────────────────────── + +def train(args): + float_type = torch.float32 + torch.set_default_dtype(torch.float32) + + device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + run_name = build_run_name(args) + args.save_dir = os.path.join(args.save_dir, run_name) + os.makedirs(args.save_dir, exist_ok=True) + print(f"Run dir : {args.save_dir}") + + if args.verbose: + print(f"Start training (fp32) num_points={args.num_points} " + f"solver={args.solver} eval_every={args.eval_every} device={device}") + + # ── Load dataset ────────────────────────────────────────────────────── + print(f"Loading dataset: {args.data_path}") + with open(args.data_path, 'rb') as f: + data = pickle.load(f) + + inertia = data['inertia_info'][args.config_idx] + I1, I2, I3 = inertia['I1'], inertia['I2'], inertia['I3'] + print(f"\nConfig {args.config_idx}: I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²") + print(f" I2/I1={I2/I1:.2f} I3/I1={I3/I1:.2f} " + f"(I3-I1)/I2={(I3-I1)/I2:.2f} (asymmetry index for Dzhanibekov strength)") + + # Slice to single config: (num_configs, T, N, 15) → (1, T, N, 15) + cfg_x = data['x'][[args.config_idx]] + cfg_test = data['test_x'][[args.config_idx]] + + # ── Build model ─────────────────────────────────────────────────────── + model = DissipativeSO3HamNODE( + device=device, u_dim=3, init_gain=args.init_gain).to(device) + + if args.fix_M: + _inner(model).M_net = FixedInertia( + I1=I1, I2=I2, I3=I3).to(device).to(float_type) + print(f"M_net fixed to diag(1/I1,1/I2,1/I3) — not trained.") + + print(f'Model: {get_model_parm_nums(model)} parameters') + + 
optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4) + + # ── Arrange data ────────────────────────────────────────────────────── + train_x, t_eval = arrange_data(cfg_x, data['t'], num_points=args.num_points) + test_x, _ = arrange_data(cfg_test, data['t'], num_points=args.num_points) + train_x_cat = np.concatenate(train_x, axis=1) + test_x_cat = np.concatenate(test_x, axis=1) + + train_x_cat = torch.tensor(train_x_cat, requires_grad=True, + dtype=float_type).to(device) + test_x_cat = torch.tensor(test_x_cat, requires_grad=True, + dtype=float_type).to(device) + t_eval = torch.tensor(t_eval, requires_grad=True, + dtype=float_type).to(device) + + # ── M_net pretraining ───────────────────────────────────────────────── + if args.pretrain_M_steps > 0 and not args.fix_M: + q_pretrain = train_x_cat.detach().reshape(-1, 15)[:, :9] + pretrain_M_net( + model=model, + q_samples=q_pretrain, + n_steps=args.pretrain_M_steps, + lr=args.pretrain_M_lr, + print_every=args.pretrain_M_print_every, + I1=I1, I2=I2, I3=I3, + ) + + split = [9, 3, 3] + + stats = { + 'train_loss': [], 'train_l2_loss': [], 'train_geo_loss': [], + 'train_power_loss': [], 'train_V_zero_loss': [], + 'train_V_cons_loss': [], 'train_B_cons_loss': [], 'train_D_cons_loss': [], + 'forward_time': [], 'backward_time': [], 'nfe': [], + # Eval (windowed) — recorded every eval_every steps + 'eval_step': [], + 'test_loss': [], 'test_l2_loss': [], 'test_geo_loss': [], + 'eval_M_loss': [], 'eval_V_loss': [], + 'eval_Dw_loss': [], 'eval_g_loss': [], + # Config metadata saved alongside training stats + 'config_idx': args.config_idx, + 'I1': I1, 'I2': I2, 'I3': I3, + } + + dt_train = (t_eval[1] - t_eval[0]).detach().item() + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + stats_path = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + + loss_buffer = [] + fwd_buffer = [] + bwd_buffer = [] + nfe_buffer = [] + + for step in 
range(args.total_steps + 1): + + # ── Training step ──────────────────────────────────────────────── + t = time.time() + train_x_hat = odeint(model, train_x_cat[0, :, :], t_eval, method=args.solver) + forward_time = time.time() - t + + target = train_x_cat[1:, :, :] + target_hat = train_x_hat[1:, :, :] + train_loss, train_l2_loss, train_geo_loss = rotmat_L2_geodesic_loss( + target, target_hat, split=split) + + if args.lambda_power > 0.0: + L_power = power_balance_loss(model, train_x_cat, dt_train) + else: + L_power = torch.zeros((), device=device, dtype=float_type) + + if args.lambda_V > 0.0 or args.lambda_B > 0.0 or args.lambda_D > 0.0: + L_V, L_B, L_D = consistency_subnet_losses(model, train_x_cat, dt_train) + else: + L_V = torch.zeros((), device=device, dtype=float_type) + L_B = torch.zeros((), device=device, dtype=float_type) + L_D = torch.zeros((), device=device, dtype=float_type) + + if args.lambda_V_zero > 0.0: + q_flat = train_x_cat.detach().reshape(-1, 15)[:, :9] + V_vals = _inner(model).V_net(q_flat).squeeze(-1) + L_V_zero = (V_vals ** 2).mean() + else: + L_V_zero = torch.zeros((), device=device, dtype=float_type) + + total_loss = (train_loss + + args.lambda_power * L_power + + args.lambda_V * L_V + + args.lambda_B * L_B + + args.lambda_D * L_D + + args.lambda_V_zero * L_V_zero) + + t = time.time() + total_loss.backward() + optim.step() + optim.zero_grad() + backward_time = time.time() - t + + loss_buffer.append(torch.stack([ + total_loss.detach(), train_l2_loss.detach(), train_geo_loss.detach(), + L_power.detach(), L_V_zero.detach(), L_V.detach(), L_B.detach(), L_D.detach() + ])) + fwd_buffer.append(forward_time) + bwd_buffer.append(backward_time) + nfe = getattr(model, 'nfe', + getattr(getattr(model, '_orig_mod', model), 'nfe', 0)) + nfe_buffer.append(nfe) + + # ── Windowed eval + diagnostics + checkpoint ───────────────────── + if step % args.eval_every == 0: + with torch.no_grad(): + test_x_hat = odeint( + model, test_x_cat[0, :, :], t_eval, 
method=args.solver) + tgt = test_x_cat[1:, :, :] + tgt_hat = test_x_hat[1:, :, :] + test_loss, test_l2_loss, test_geo_loss = rotmat_L2_geodesic_loss( + tgt, tgt_hat, split=split) + subnet = subnet_physics_mse_tennis( + model, test_x_hat, I1=I1, I2=I2, I3=I3) + + if loss_buffer: + drained = torch.stack(loss_buffer, dim=0).cpu().numpy() # (N, 8) + stats['train_loss'].extend(drained[:, 0].tolist()) + stats['train_l2_loss'].extend(drained[:, 1].tolist()) + stats['train_geo_loss'].extend(drained[:, 2].tolist()) + stats['train_power_loss'].extend(drained[:, 3].tolist()) + stats['train_V_zero_loss'].extend(drained[:, 4].tolist()) + stats['train_V_cons_loss'].extend(drained[:, 5].tolist()) + stats['train_B_cons_loss'].extend(drained[:, 6].tolist()) + stats['train_D_cons_loss'].extend(drained[:, 7].tolist()) + stats['forward_time'].extend(fwd_buffer) + stats['backward_time'].extend(bwd_buffer) + stats['nfe'].extend(nfe_buffer) + loss_buffer = []; fwd_buffer = []; bwd_buffer = []; nfe_buffer = [] + + test_pack = torch.stack([ + test_loss.detach(), test_l2_loss.detach(), test_geo_loss.detach() + ]).cpu().numpy() + train_total = stats['train_loss'][-1] + train_l2 = stats['train_l2_loss'][-1] + train_geo = stats['train_geo_loss'][-1] + + stats['eval_step'].append(step) + stats['test_loss'].append(float(test_pack[0])) + stats['test_l2_loss'].append(float(test_pack[1])) + stats['test_geo_loss'].append(float(test_pack[2])) + stats['eval_M_loss'].append(subnet['M_loss']) + stats['eval_V_loss'].append(subnet['V_loss']) + stats['eval_Dw_loss'].append(subnet['Dw_loss']) + stats['eval_g_loss'].append(subnet['g_loss']) + + train_pow = stats['train_power_loss'][-1] + train_V_zero = stats['train_V_zero_loss'][-1] + train_LV = stats['train_V_cons_loss'][-1] + train_LB = stats['train_B_cons_loss'][-1] + train_LD = stats['train_D_cons_loss'][-1] + + print(f"[step {step:>6d}]") + print(f" train: total={train_total:.4e} " + f"L2={train_l2:.4e} geo={train_geo:.4e} " + f"power={train_pow:.4e} 
V0={train_V_zero:.4e}") + print(f" cons : L_V={train_LV:.4e} " + f"L_B={train_LB:.4e} L_D={train_LD:.4e}") + print(f" test : total={test_pack[0]:.4e} " + f"L2={test_pack[1]:.4e} geo={test_pack[2]:.4e}") + print(f" subnet MSE M={subnet['M_loss']:.3e} " + f"V={subnet['V_loss']:.3e} Dw={subnet['Dw_loss']:.3e} " + f"g={subnet['g_loss']:.3e} | nfe={nfe}") + + ckpt = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-{step}.tar') + torch.save(_state_dict(model), ckpt) + to_pickle(stats, stats_path) + + # ── Final per-trajectory eval ───────────────────────────────────────── + cfg_x_full = torch.tensor(cfg_x, requires_grad=True, + dtype=float_type).to(device) + cfg_test_full = torch.tensor(cfg_test, requires_grad=True, + dtype=float_type).to(device) + t_full = torch.tensor(data['t'], requires_grad=True, + dtype=float_type).to(device) + + train_loss_l, test_loss_l = [], [] + train_l2_l, test_l2_l = [], [] + train_geo_l, test_geo_l = [], [] + train_data_hat, test_data_hat = [], [] + + for i in range(cfg_x_full.shape[0]): # single iteration for Stage A + train_x_hat = odeint( + model, cfg_x_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_x_full[i, :, :, :], train_x_hat, split=split) + train_loss_l.append(total_loss) + train_l2_l.append(l2_loss) + train_geo_l.append(geo_loss) + train_data_hat.append(train_x_hat.detach().cpu().numpy()) + + test_x_hat = odeint( + model, cfg_test_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_test_full[i, :, :, :], test_x_hat, split=split) + test_loss_l.append(total_loss) + test_l2_l.append(l2_loss) + test_geo_l.append(geo_loss) + test_data_hat.append(test_x_hat.detach().cpu().numpy()) + + def _per_traj(loss_list): + return torch.sum(torch.cat(loss_list, dim=1), dim=0) + + train_loss_pt = _per_traj(train_loss_l) + test_loss_pt = _per_traj(test_loss_l) + train_l2_pt = 
_per_traj(train_l2_l) + test_l2_pt = _per_traj(test_l2_l) + train_geo_pt = _per_traj(train_geo_l) + test_geo_pt = _per_traj(test_geo_l) + + print('Final trajectory train loss {:.4e} +/- {:.4e}\n' + 'Final trajectory test loss {:.4e} +/- {:.4e}'.format( + train_loss_pt.mean().item(), train_loss_pt.std().item(), + test_loss_pt.mean().item(), test_loss_pt.std().item())) + print('Final trajectory train geo {:.4e} +/- {:.4e}\n' + 'Final trajectory test geo {:.4e} +/- {:.4e}'.format( + train_geo_pt.mean().item(), train_geo_pt.std().item(), + test_geo_pt.mean().item(), test_geo_pt.std().item())) + + stats['traj_train_loss'] = train_loss_pt.detach().cpu().numpy() + stats['traj_test_loss'] = test_loss_pt.detach().cpu().numpy() + stats['train_x'] = cfg_x_full.detach().cpu().numpy() + stats['test_x'] = cfg_test_full.detach().cpu().numpy() + stats['train_x_hat'] = np.array(train_data_hat) + stats['test_x_hat'] = np.array(test_data_hat) + stats['t_eval'] = t_full.detach().cpu().numpy() + return model, stats + + +# ───────────────────────────────────────────────────────────────────────────── +# Entry point +# ───────────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + args = get_args() + model, stats = train(args) + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + final_ckpt = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p.tar') + torch.save(_state_dict(model), final_ckpt) + final_stats = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + print("Saved final stats:", final_stats) + to_pickle(stats, final_stats) diff --git a/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_stageB.py b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_stageB.py new file mode 100644 index 0000000..ce5583b --- /dev/null +++ b/src/models/3D_SO3_Tennis_Racket/ph_nn_ode_v2/train_stageB.py @@ -0,0 +1,526 @@ +"""Stage B training: viscous friction data, learn Dw_net → 
friction_coeff · I₃. + +Adapted from train.py. Changes from Stage A: + - --data_path defaults to the friction dataset (fric0p01). + - --friction_coeff is read and passed to subnet_physics_mse_tennis so the + Dw diagnostic target is friction_coeff · I₃ instead of 0. + - --fix_M still recommended (M is pinned to ground truth, same as Stage A). + +Stage B pass criteria: + 1. eval_M_loss = 0 (pinned via --fix_M) + 2. test_geo_loss < 0.01 rad² + 3. eval_Dw_loss falls well below (friction_coeff² / 3) and converges toward 0 — the MSE of a + perfectly-learned scalar-times-identity dissipation matrix goes to 0, + so eval_Dw_loss should be much smaller than friction_coeff² + 4. det(R̂) ≈ 1 throughout +""" +import torch, argparse +import numpy as np +import os, sys +import time +import pickle + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) + +# Shared network.py and loss_utils.py live in the pendulum model directory. +# The architecture is identical; we reuse without copying. 
+PENDULUM_ODE_DIR = os.path.join( + PROJECT_ROOT, 'src/models/3D_SO3_Windy_Pendulum/ph_nn_ode_v2') + +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src/utils')) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'datasets')) +sys.path.insert(0, PENDULUM_ODE_DIR) + +from torchdiffeq import odeint + +from ode_utils import to_pickle +from subnet_diagnostics_tennis import subnet_physics_mse_tennis +from tennis_racket_3d_datagen import arrange_data +from network import DissipativeSO3HamNODE +from loss_utils import ( + rotmat_L2_geodesic_loss_safe as rotmat_L2_geodesic_loss, + traj_rotmat_L2_geodesic_loss_safe as traj_rotmat_L2_geodesic_loss, + power_balance_loss, + consistency_subnet_losses, +) + + +DEFAULT_SAVE_DIR = os.path.join(THIS_FILE_DIR, 'data', 'run_tr3d_stageB_fp32') +DEFAULT_DATA_PATH = os.path.join( + PROJECT_ROOT, + 'data/tennis_data/' + 'tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_fric0p01_ncfg4_steps100.pkl' +) + + +# ───────────────────────────────────────────────────────────────────────────── +# Fixed inverse-inertia module (drop-in M_net, selected with --fix_M) +# ───────────────────────────────────────────────────────────────────────────── + +class FixedInertia(torch.nn.Module): + """Drop-in M_net that returns diag(1/I1, 1/I2, 1/I3) — no learnable parameters. + + M_net outputs M⁻¹ (the inverse inertia tensor). For the tennis racket this + is diag(1/I1, 1/I2, 1/I3). Use --fix_M to pin M_net here and let the + optimizer focus on V_net, Dw_net, g_net. 
+ """ + def __init__(self, I1: float, I2: float, I3: float): + super().__init__() + I_inv = torch.tensor([1.0/I1, 1.0/I2, 1.0/I3], dtype=torch.float32) + self.register_buffer('M_inv', torch.diag(I_inv)) + + def forward(self, q): + return self.M_inv.unsqueeze(0).expand(q.shape[0], 3, 3) + + +# ───────────────────────────────────────────────────────────────────────────── +# CLI +# ───────────────────────────────────────────────────────────────────────────── + +def get_args(): + parser = argparse.ArgumentParser(description=None) + parser.add_argument('--learn_rate', default=1e-3, type=float) + parser.add_argument('--total_steps', default=10000, type=int) + parser.add_argument('--eval_every', default=50, type=int, + help='windowed eval + diagnostics + checkpoint cadence') + parser.add_argument('--name', default='tr3d', type=str) + parser.add_argument('--verbose', action='store_true') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--save_dir', default=DEFAULT_SAVE_DIR, type=str) + parser.add_argument('--data_path', default=DEFAULT_DATA_PATH, type=str, + help='path to the tennis-racket .pkl dataset') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--num_points', type=int, default=5) + parser.add_argument('--solver', default='rk4', type=str) + parser.add_argument('--init_gain', default=0.5, type=float) + + parser.add_argument('--samples', type=int, default=25) + parser.add_argument('--timesteps', type=int, default=100) + parser.add_argument('--obs_noise_std', type=float, default=0.0) + parser.add_argument('--disturbance_torque_std', type=float, default=0.0) + parser.add_argument('--perturb_std', type=float, default=0.05) + + # Which geometry config in the dataset to train on (0-indexed). 
+ parser.add_argument('--config_idx', type=int, default=0, + help='index into data[inertia_info] to select (0-3 for the default 4-config dataset)') + + # M_net pretraining: anchors M_net near diag(1/I1,1/I2,1/I3) before joint training. + parser.add_argument('--pretrain_M_steps', type=int, default=200) + parser.add_argument('--pretrain_M_lr', type=float, default=1e-3) + parser.add_argument('--pretrain_M_print_every', type=int, default=20) + + # Physics-informed auxiliary losses (off by default). + parser.add_argument('--lambda_power', type=float, default=0.0, + help='weight for power-balance loss; 0 disables') + parser.add_argument('--lambda_V', type=float, default=0.0, + help='weight for V back-solving loss; 0 disables') + parser.add_argument('--lambda_B', type=float, default=0.0, + help='weight for B (input coupling) back-solving loss; 0 disables') + parser.add_argument('--lambda_D', type=float, default=0.0, + help='weight for D (dissipation) back-solving loss; 0 disables') + + # Pin M_net to the analytic ground-truth diag(1/I1,1/I2,1/I3). + parser.add_argument('--fix_M', action='store_true', + help='use FixedInertia (non-learnable M_net) with true I from the data') + + # Friction coefficient — must match the dataset used. + # Passed to subnet_physics_mse_tennis so Dw diagnostic target = friction_coeff·I₃. 
+ parser.add_argument('--friction_coeff', type=float, default=0.01, + help='viscous friction coefficient used in data generation') + + return parser.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Utilities +# ───────────────────────────────────────────────────────────────────────────── + +def get_model_parm_nums(model): + return sum(p.nelement() for p in model.parameters()) + + +def _fmt_num(x): + """Format a number for filesystem paths: 0.01 → 0p01, -1.0 → n1.""" + s = f"{x:g}" + return s.replace('.', 'p').replace('-', 'n').replace('+', '') + + +def build_run_name(args): + """Build a run-folder name encoding training specs + a YYMMDD-HHMMSS stamp.""" + parts = [ + f"obs{_fmt_num(args.obs_noise_std)}", + f"dist{_fmt_num(args.disturbance_torque_std)}", + f"cfg{args.config_idx}", + f"lP{_fmt_num(args.lambda_power)}", + f"lV{_fmt_num(args.lambda_V)}", + f"lB{_fmt_num(args.lambda_B)}", + f"lD{_fmt_num(args.lambda_D)}", + f"lr{_fmt_num(args.learn_rate)}", + f"s{args.total_steps}", + f"np{args.num_points}", + f"smp{args.samples}", + f"T{args.timesteps}", + f"{args.solver}", + f"seed{args.seed}", + ] + if args.fix_M: + parts.append('fixM') + stamp = time.strftime('%y%m%d-%H%M%S') + return '_'.join(parts) + '_' + stamp + + +def _state_dict(model): + return (model._orig_mod if hasattr(model, '_orig_mod') else model).state_dict() + + +def _inner(model): + """Unwrap a possibly torch.compile-wrapped model.""" + return model._orig_mod if hasattr(model, '_orig_mod') else model + + +# ───────────────────────────────────────────────────────────────────────────── +# M_net pretraining +# ───────────────────────────────────────────────────────────────────────────── + +def pretrain_M_net(model, q_samples, n_steps, lr, print_every, I1, I2, I3): + """Pretrain M_net to output the true inverse inertia diag(1/I1, 1/I2, 1/I3). 
+ + For a free rigid body the inertia tensor is constant (independent of R), so + M_net needs to learn a constant diagonal PSD matrix. Pretraining anchors it + near the correct value before joint training begins. + """ + if n_steps <= 0: + return + + inner = _inner(model) + device = q_samples.device + dtype = q_samples.dtype + + I_inv_diag = torch.tensor([1.0/I1, 1.0/I2, 1.0/I3], dtype=dtype, device=device) + target_mat = torch.diag(I_inv_diag) + target = target_mat.unsqueeze(0).expand(q_samples.shape[0], 3, 3) + + print(f"\nPretraining M_net for {n_steps} steps (lr={lr})") + print(f" target: diag(1/I1,1/I2,1/I3) = " + f"diag({1/I1:.4f}, {1/I2:.4f}, {1/I3:.4f})") + print(f" {q_samples.shape[0]} q-samples drawn from training data") + + optim = torch.optim.Adam(inner.M_net.parameters(), lr=lr, weight_decay=1e-4) + q_no_grad = q_samples.detach() + + initial_loss = None + for step in range(n_steps): + M_pred = inner.M_net(q_no_grad) + loss = (M_pred - target).pow(2).mean() + loss.backward() + optim.step() + optim.zero_grad() + if initial_loss is None: + initial_loss = loss.item() + if step % max(1, print_every) == 0 or step == n_steps - 1: + print(f" pretrain step {step:>4d}: loss={loss.item():.3e}") + + with torch.no_grad(): + M_check = inner.M_net(q_no_grad[:1]) + deviation = (M_check - target_mat).abs().max().item() + print(f" pretrain done. 
initial={initial_loss:.3e} " + f"final={loss.item():.3e} max|M(q₀)−target|={deviation:.3e}") + + +# ───────────────────────────────────────────────────────────────────────────── +# Training +# ───────────────────────────────────────────────────────────────────────────── + +def train(args): + float_type = torch.float32 + torch.set_default_dtype(torch.float32) + + device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + run_name = build_run_name(args) + args.save_dir = os.path.join(args.save_dir, run_name) + os.makedirs(args.save_dir, exist_ok=True) + print(f"Run dir : {args.save_dir}") + + if args.verbose: + print(f"Start training (fp32) num_points={args.num_points} " + f"solver={args.solver} eval_every={args.eval_every} device={device}") + + # ── Load dataset ────────────────────────────────────────────────────── + print(f"Loading dataset: {args.data_path}") + with open(args.data_path, 'rb') as f: + data = pickle.load(f) + + inertia = data['inertia_info'][args.config_idx] + I1, I2, I3 = inertia['I1'], inertia['I2'], inertia['I3'] + print(f"\nConfig {args.config_idx}: I1={I1:.6f} I2={I2:.6f} I3={I3:.6f} kg·m²") + print(f" I2/I1={I2/I1:.2f} I3/I1={I3/I1:.2f} " + f"(I3-I1)/I2={(I3-I1)/I2:.2f} (asymmetry index for Dzhanibekov strength)") + + # Slice to single config: (num_configs, T, N, 15) → (1, T, N, 15) + cfg_x = data['x'][[args.config_idx]] + cfg_test = data['test_x'][[args.config_idx]] + + # ── Build model ─────────────────────────────────────────────────────── + model = DissipativeSO3HamNODE( + device=device, u_dim=3, init_gain=args.init_gain).to(device) + + if args.fix_M: + _inner(model).M_net = FixedInertia( + I1=I1, I2=I2, I3=I3).to(device).to(float_type) + print(f"M_net fixed to diag(1/I1,1/I2,1/I3) — not trained.") + + print(f'Model: {get_model_parm_nums(model)} parameters') + + 
optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4) + + # ── Arrange data ────────────────────────────────────────────────────── + train_x, t_eval = arrange_data(cfg_x, data['t'], num_points=args.num_points) + test_x, _ = arrange_data(cfg_test, data['t'], num_points=args.num_points) + train_x_cat = np.concatenate(train_x, axis=1) + test_x_cat = np.concatenate(test_x, axis=1) + + train_x_cat = torch.tensor(train_x_cat, requires_grad=True, + dtype=float_type).to(device) + test_x_cat = torch.tensor(test_x_cat, requires_grad=True, + dtype=float_type).to(device) + t_eval = torch.tensor(t_eval, requires_grad=True, + dtype=float_type).to(device) + + # ── M_net pretraining ───────────────────────────────────────────────── + if args.pretrain_M_steps > 0 and not args.fix_M: + q_pretrain = train_x_cat.detach().reshape(-1, 15)[:, :9] + pretrain_M_net( + model=model, + q_samples=q_pretrain, + n_steps=args.pretrain_M_steps, + lr=args.pretrain_M_lr, + print_every=args.pretrain_M_print_every, + I1=I1, I2=I2, I3=I3, + ) + + split = [9, 3, 3] + + stats = { + 'train_loss': [], 'train_l2_loss': [], 'train_geo_loss': [], + 'train_power_loss': [], + 'train_V_cons_loss': [], 'train_B_cons_loss': [], 'train_D_cons_loss': [], + 'forward_time': [], 'backward_time': [], 'nfe': [], + # Eval (windowed) — recorded every eval_every steps + 'eval_step': [], + 'test_loss': [], 'test_l2_loss': [], 'test_geo_loss': [], + 'eval_M_loss': [], 'eval_V_loss': [], + 'eval_Dw_loss': [], 'eval_g_loss': [], + # Config metadata saved alongside training stats + 'config_idx': args.config_idx, + 'I1': I1, 'I2': I2, 'I3': I3, + } + + dt_train = (t_eval[1] - t_eval[0]).detach().item() + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + stats_path = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + + loss_buffer = [] + fwd_buffer = [] + bwd_buffer = [] + nfe_buffer = [] + + for step in range(args.total_steps + 1): + + # ── 
Training step ──────────────────────────────────────────────── + t = time.time() + train_x_hat = odeint(model, train_x_cat[0, :, :], t_eval, method=args.solver) + forward_time = time.time() - t + + target = train_x_cat[1:, :, :] + target_hat = train_x_hat[1:, :, :] + train_loss, train_l2_loss, train_geo_loss = rotmat_L2_geodesic_loss( + target, target_hat, split=split) + + if args.lambda_power > 0.0: + L_power = power_balance_loss(model, train_x_cat, dt_train) + else: + L_power = torch.zeros((), device=device, dtype=float_type) + + if args.lambda_V > 0.0 or args.lambda_B > 0.0 or args.lambda_D > 0.0: + L_V, L_B, L_D = consistency_subnet_losses(model, train_x_cat, dt_train) + else: + L_V = torch.zeros((), device=device, dtype=float_type) + L_B = torch.zeros((), device=device, dtype=float_type) + L_D = torch.zeros((), device=device, dtype=float_type) + + total_loss = (train_loss + + args.lambda_power * L_power + + args.lambda_V * L_V + + args.lambda_B * L_B + + args.lambda_D * L_D) + + t = time.time() + total_loss.backward() + optim.step() + optim.zero_grad() + backward_time = time.time() - t + + loss_buffer.append(torch.stack([ + total_loss.detach(), train_l2_loss.detach(), train_geo_loss.detach(), + L_power.detach(), L_V.detach(), L_B.detach(), L_D.detach() + ])) + fwd_buffer.append(forward_time) + bwd_buffer.append(backward_time) + nfe = getattr(model, 'nfe', + getattr(getattr(model, '_orig_mod', model), 'nfe', 0)) + nfe_buffer.append(nfe) + + # ── Windowed eval + diagnostics + checkpoint ───────────────────── + if step % args.eval_every == 0: + with torch.no_grad(): + test_x_hat = odeint( + model, test_x_cat[0, :, :], t_eval, method=args.solver) + tgt = test_x_cat[1:, :, :] + tgt_hat = test_x_hat[1:, :, :] + test_loss, test_l2_loss, test_geo_loss = rotmat_L2_geodesic_loss( + tgt, tgt_hat, split=split) + subnet = subnet_physics_mse_tennis( + model, test_x_hat, I1=I1, I2=I2, I3=I3, + friction_coeff=args.friction_coeff) + + if loss_buffer: + drained = 
torch.stack(loss_buffer, dim=0).cpu().numpy() # (N, 7) + stats['train_loss'].extend(drained[:, 0].tolist()) + stats['train_l2_loss'].extend(drained[:, 1].tolist()) + stats['train_geo_loss'].extend(drained[:, 2].tolist()) + stats['train_power_loss'].extend(drained[:, 3].tolist()) + stats['train_V_cons_loss'].extend(drained[:, 4].tolist()) + stats['train_B_cons_loss'].extend(drained[:, 5].tolist()) + stats['train_D_cons_loss'].extend(drained[:, 6].tolist()) + stats['forward_time'].extend(fwd_buffer) + stats['backward_time'].extend(bwd_buffer) + stats['nfe'].extend(nfe_buffer) + loss_buffer = []; fwd_buffer = []; bwd_buffer = []; nfe_buffer = [] + + test_pack = torch.stack([ + test_loss.detach(), test_l2_loss.detach(), test_geo_loss.detach() + ]).cpu().numpy() + train_total = stats['train_loss'][-1] + train_l2 = stats['train_l2_loss'][-1] + train_geo = stats['train_geo_loss'][-1] + + stats['eval_step'].append(step) + stats['test_loss'].append(float(test_pack[0])) + stats['test_l2_loss'].append(float(test_pack[1])) + stats['test_geo_loss'].append(float(test_pack[2])) + stats['eval_M_loss'].append(subnet['M_loss']) + stats['eval_V_loss'].append(subnet['V_loss']) + stats['eval_Dw_loss'].append(subnet['Dw_loss']) + stats['eval_g_loss'].append(subnet['g_loss']) + + train_pow = stats['train_power_loss'][-1] + train_LV = stats['train_V_cons_loss'][-1] + train_LB = stats['train_B_cons_loss'][-1] + train_LD = stats['train_D_cons_loss'][-1] + + print(f"[step {step:>6d}]") + print(f" train: total={train_total:.4e} " + f"L2={train_l2:.4e} geo={train_geo:.4e} " + f"power={train_pow:.4e}") + print(f" cons : L_V={train_LV:.4e} " + f"L_B={train_LB:.4e} L_D={train_LD:.4e}") + print(f" test : total={test_pack[0]:.4e} " + f"L2={test_pack[1]:.4e} geo={test_pack[2]:.4e}") + print(f" subnet MSE M={subnet['M_loss']:.3e} " + f"V={subnet['V_loss']:.3e} Dw={subnet['Dw_loss']:.3e} " + f"g={subnet['g_loss']:.3e} | nfe={nfe}") + + ckpt = (f'{args.save_dir}/{args.name}{label}' + 
f'-{args.solver}-{args.num_points}p-{step}.tar') + torch.save(_state_dict(model), ckpt) + to_pickle(stats, stats_path) + + # ── Final per-trajectory eval ───────────────────────────────────────── + cfg_x_full = torch.tensor(cfg_x, requires_grad=True, + dtype=float_type).to(device) + cfg_test_full = torch.tensor(cfg_test, requires_grad=True, + dtype=float_type).to(device) + t_full = torch.tensor(data['t'], requires_grad=True, + dtype=float_type).to(device) + + train_loss_l, test_loss_l = [], [] + train_l2_l, test_l2_l = [], [] + train_geo_l, test_geo_l = [], [] + train_data_hat, test_data_hat = [], [] + + for i in range(cfg_x_full.shape[0]): # single iteration for Stage A + train_x_hat = odeint( + model, cfg_x_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_x_full[i, :, :, :], train_x_hat, split=split) + train_loss_l.append(total_loss) + train_l2_l.append(l2_loss) + train_geo_l.append(geo_loss) + train_data_hat.append(train_x_hat.detach().cpu().numpy()) + + test_x_hat = odeint( + model, cfg_test_full[i, 0, :, :], t_full, method=args.solver) + total_loss, l2_loss, geo_loss = traj_rotmat_L2_geodesic_loss( + cfg_test_full[i, :, :, :], test_x_hat, split=split) + test_loss_l.append(total_loss) + test_l2_l.append(l2_loss) + test_geo_l.append(geo_loss) + test_data_hat.append(test_x_hat.detach().cpu().numpy()) + + def _per_traj(loss_list): + return torch.sum(torch.cat(loss_list, dim=1), dim=0) + + train_loss_pt = _per_traj(train_loss_l) + test_loss_pt = _per_traj(test_loss_l) + train_l2_pt = _per_traj(train_l2_l) + test_l2_pt = _per_traj(test_l2_l) + train_geo_pt = _per_traj(train_geo_l) + test_geo_pt = _per_traj(test_geo_l) + + print('Final trajectory train loss {:.4e} +/- {:.4e}\n' + 'Final trajectory test loss {:.4e} +/- {:.4e}'.format( + train_loss_pt.mean().item(), train_loss_pt.std().item(), + test_loss_pt.mean().item(), test_loss_pt.std().item())) + print('Final trajectory train geo {:.4e} +/- 
if __name__ == "__main__":
    # Script entry point: parse the CLI, run training, then persist the final
    # model weights and the stats dictionary returned by train().
    args = get_args()
    model, stats = train(args)

    # Make sure the output directory exists before anything is written.
    os.makedirs(args.save_dir, exist_ok=True)
    label = '-so3ham'

    # Final checkpoint: same naming scheme as the periodic checkpoints but
    # without the trailing step number.
    final_ckpt = (f'{args.save_dir}/{args.name}{label}'
                  f'-{args.solver}-{args.num_points}p.tar')
    torch.save(_state_dict(model), final_ckpt)

    final_stats = (f'{args.save_dir}/{args.name}{label}'
                   f'-{args.solver}-{args.num_points}p-stats.pkl')
    # Write first, report afterwards, so the message is only printed when the
    # stats file really exists on disk.
    to_pickle(stats, final_stats)
    print("Saved final stats:", final_stats)
+ - augment_with_inertia() inserts (I₁,I₂,I₃) into every state vector. + - All 4 configs are concatenated along the batch axis for joint training. + - Per-config eval reported separately so generalisation is visible. + - loss_utils.py is still imported from the pendulum directory (unchanged). + +Stage C pass criteria: + 1. eval_M_loss < 1e-4 across all 4 configs (FixedInertiaFromState gives 0) + 2. test_geo_loss < 0.01 rad² for all 4 configs + 3. eval_V_loss, eval_Dw_loss small for all configs + 4. Generalisation: held-out 5th config (unseen geometry) achieves geo < 0.01 +""" +import torch, argparse +import numpy as np +import os, sys +import time +import pickle + +THIS_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) +PROJECT_ROOT = os.path.abspath(os.path.join(THIS_FILE_DIR, '../../../..')) + +PENDULUM_ODE_DIR = os.path.join( + PROJECT_ROOT, 'src/models/3D_SO3_Windy_Pendulum/ph_nn_ode_v2') + +sys.path.insert(0, PROJECT_ROOT) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src/utils')) +sys.path.insert(0, os.path.join(PROJECT_ROOT, 'datasets')) +sys.path.insert(0, PENDULUM_ODE_DIR) +sys.path.insert(0, THIS_FILE_DIR) # for network_stageC + +from torchdiffeq import odeint + +from ode_utils import to_pickle +from subnet_diagnostics_stageC import subnet_physics_mse_stageC +from tennis_racket_3d_datagen import arrange_data +from network_stageC import DissipativeSO3HamNODE, FixedInertiaFromState +from loss_utils import ( + rotmat_L2_geodesic_loss_safe as rotmat_L2_geodesic_loss, + traj_rotmat_L2_geodesic_loss_safe as traj_rotmat_L2_geodesic_loss, + power_balance_loss, + consistency_subnet_losses, +) + + +DEFAULT_SAVE_DIR = os.path.join(THIS_FILE_DIR, 'data', 'run_tr3d_stageC_fp32') +DEFAULT_DATA_PATH = os.path.join( + PROJECT_ROOT, + 'data/tennis_data/' + 'tr3d_dataset_dist0p0_obs_noise0p0_perturb0p05_ncfg4_steps100.pkl' +) + + +# FixedInertiaFromState is imported from network_stageC — reads I from state. 
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────

def get_args(argv=None):
    """Parse the Stage C training options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows programmatic use.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--learn_rate', default=1e-3, type=float)
    parser.add_argument('--total_steps', default=10000, type=int)
    parser.add_argument('--eval_every', default=50, type=int,
                        help='windowed eval + diagnostics + checkpoint cadence')
    parser.add_argument('--name', default='tr3d', type=str)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--save_dir', default=DEFAULT_SAVE_DIR, type=str)
    parser.add_argument('--data_path', default=DEFAULT_DATA_PATH, type=str,
                        help='path to the tennis-racket .pkl dataset')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--num_points', type=int, default=5)
    parser.add_argument('--solver', default='rk4', type=str)
    parser.add_argument('--init_gain', default=0.5, type=float)

    parser.add_argument('--samples', type=int, default=25)
    parser.add_argument('--timesteps', type=int, default=100)
    parser.add_argument('--obs_noise_std', type=float, default=0.0)
    parser.add_argument('--disturbance_torque_std', type=float, default=0.0)
    parser.add_argument('--perturb_std', type=float, default=0.05)

    # M_net pretraining: anchors M_net near diag(1/I1,1/I2,1/I3) before joint training.
    parser.add_argument('--pretrain_M_steps', type=int, default=200)
    parser.add_argument('--pretrain_M_lr', type=float, default=1e-3)
    parser.add_argument('--pretrain_M_print_every', type=int, default=20)

    # Physics-informed auxiliary losses (off by default).
    parser.add_argument('--lambda_power', type=float, default=0.0,
                        help='weight for power-balance loss; 0 disables')
    parser.add_argument('--lambda_V', type=float, default=0.0,
                        help='weight for V back-solving loss; 0 disables')
    parser.add_argument('--lambda_B', type=float, default=0.0,
                        help='weight for B (input coupling) back-solving loss; 0 disables')
    parser.add_argument('--lambda_D', type=float, default=0.0,
                        help='weight for D (dissipation) back-solving loss; 0 disables')

    # Pin M_net to the analytic ground-truth diag(1/I1,1/I2,1/I3).
    parser.add_argument('--fix_M', action='store_true',
                        help='use FixedInertiaFromState (non-learnable M_net) that '
                             'reads the true I1,I2,I3 embedded in the state')

    return parser.parse_args(argv)


# ─────────────────────────────────────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────────────────────────────────────

def augment_with_inertia(x_np, I1, I2, I3):
    """Insert (I1,I2,I3) into every state vector after vec(R).

    Input:  x_np  (..., 15) — [vec(R)(9), ω(3), u(3)], typically (T, N, 15)
    Output: x_aug (..., 18) — [vec(R)(9), I1I2I3(3), ω(3), u(3)]

    Any number of leading dimensions is accepted, mirroring strip_inertia.
    The inertia values are cast to ``x_np.dtype`` so the result keeps the
    dtype of the input array.
    """
    inertia = np.array([I1, I2, I3], dtype=x_np.dtype)
    # Read-only broadcast view; np.concatenate copies it into the result.
    inertia_cols = np.broadcast_to(inertia, x_np.shape[:-1] + (3,))
    return np.concatenate(
        [x_np[..., :9], inertia_cols, x_np[..., 9:]], axis=-1)


def strip_inertia(x):
    """Remove the embedded I₁I₂I₃ columns, (*, 18) → (*, 15).

    The 18D Stage C state is [vec(R)(9), I₁I₂I₃(3), ω(3), u(3)].
    Removing positions 9:12 gives the 15D form [vec(R), ω, u] that
    rotmat_L2_geodesic_loss_safe expects with split=[9,3,3].
    Works for any leading dimensions (T, B, …).
    """
    return torch.cat([x[..., :9], x[..., 12:]], dim=-1)


def get_model_parm_nums(model):
    """Return the total number of elements over all of ``model.parameters()``."""
    return sum(p.nelement() for p in model.parameters())


def _fmt_num(x):
    """Format a number for filesystem paths: 0.01 → 0p01, -1.0 → n1.

    Uses the ``g`` format, then maps '.' → 'p', '-' → 'n' and drops '+',
    so exponents survive as e.g. 1e-05 → 1en05 and 1e+20 → 1e20.
    """
    s = f"{x:g}"
    return s.replace('.', 'p').replace('-', 'n').replace('+', '')


def build_run_name(args):
    """Build a run-folder name encoding training specs + a YYMMDD-HHMMSS stamp.

    The stamp is taken from the local clock at call time, so two calls in
    different seconds yield different names.
    """
    parts = [
        f"obs{_fmt_num(args.obs_noise_std)}",
        f"dist{_fmt_num(args.disturbance_torque_std)}",
        "cfgALL",  # Stage C always trains on all configs jointly
        f"lP{_fmt_num(args.lambda_power)}",
        f"lV{_fmt_num(args.lambda_V)}",
        f"lB{_fmt_num(args.lambda_B)}",
        f"lD{_fmt_num(args.lambda_D)}",
        f"lr{_fmt_num(args.learn_rate)}",
        f"s{args.total_steps}",
        f"np{args.num_points}",
        f"smp{args.samples}",
        f"T{args.timesteps}",
        f"{args.solver}",
        f"seed{args.seed}",
    ]
    if args.fix_M:
        parts.append('fixM')
    stamp = time.strftime('%y%m%d-%H%M%S')
    return '_'.join(parts) + '_' + stamp


def _state_dict(model):
    """Return the state dict of the unwrapped model (see _inner)."""
    return _inner(model).state_dict()


def _inner(model):
    """Unwrap a possibly torch.compile-wrapped model."""
    return model._orig_mod if hasattr(model, '_orig_mod') else model


# ─────────────────────────────────────────────────────────────────────────────
# M_net pretraining
# ─────────────────────────────────────────────────────────────────────────────

def pretrain_M_net(model, q_samples, n_steps, lr, print_every):
    """Pretrain M_net toward diag(1/I₁,1/I₂,1/I₃) using per-sample targets.

    q_samples must be (N, 12) = (vec(R), I₁,I₂,I₃). Inertia values are read
    from columns 9:12 so each sample gets the correct target for its config.

    Args:
        model: object exposing ``M_net`` (optionally behind ``_orig_mod``).
        q_samples: tensor fed to ``M_net``; only ``M_net`` is optimised.
        n_steps: number of full-batch Adam steps; ``<= 0`` is a no-op.
        lr: Adam learning rate (weight decay is fixed at 1e-4).
        print_every: progress print cadence; values below 1 are treated as 1.

    Returns:
        None. ``M_net`` is updated in place.
    """
    if n_steps <= 0:
        return

    inner = _inner(model)

    # A parameter-free or fully frozen M_net cannot be optimised: Adam rejects
    # an empty parameter list and backward() has nothing to differentiate.
    if not any(p.requires_grad for p in inner.M_net.parameters()):
        print("\nSkipping M_net pretraining: M_net has no trainable parameters")
        return

    q_no_grad = q_samples.detach()
    I_vals = q_no_grad[:, 9:12]                          # (N, 3)
    target = torch.diag_embed(1.0 / I_vals).detach()     # (N, 3, 3)

    print(f"\nPretraining M_net for {n_steps} steps (lr={lr})")
    print(f"  {q_no_grad.shape[0]} q-samples, per-sample targets from embedded inertia")

    optim = torch.optim.Adam(inner.M_net.parameters(), lr=lr, weight_decay=1e-4)

    initial_loss = None
    for step in range(n_steps):
        M_pred = inner.M_net(q_no_grad)
        loss = (M_pred - target).pow(2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad()
        if initial_loss is None:
            initial_loss = loss.item()
        if step % max(1, print_every) == 0 or step == n_steps - 1:
            print(f"  pretrain step {step:>4d}: loss={loss.item():.3e}")

    # Spot-check the first sample only. `loss` below is the value computed at
    # the start of the last iteration, i.e. before the final optimiser step.
    with torch.no_grad():
        M_check = inner.M_net(q_no_grad[:1])
        deviation = (M_check - target[:1]).abs().max().item()
    print(f"  pretrain done.  initial={initial_loss:.3e}  "
          f"final={loss.item():.3e}  max|M(q₀)−target|={deviation:.3e}")
print(f"Loading dataset: {args.data_path}") + with open(args.data_path, 'rb') as f: + data = pickle.load(f) + + all_inertia = data['inertia_info'] + num_configs = len(all_inertia) + print(f"\nDataset: {num_configs} geometry configs") + for ci, info in enumerate(all_inertia): + I1c, I2c, I3c = info['I1'], info['I2'], info['I3'] + print(f" cfg{ci}: I1={I1c:.6f} I2={I2c:.6f} I3={I3c:.6f} kg·m² " + f"(I2/I1={I2c/I1c:.2f} asym=(I3-I1)/I2={(I3c-I1c)/I2c:.2f})") + + # Augment each config's data with its embedded inertia: (T,N,15) → (T,N,18) + x_aug_all = np.stack([ + augment_with_inertia(data['x'][ci], + all_inertia[ci]['I1'], all_inertia[ci]['I2'], all_inertia[ci]['I3']) + for ci in range(num_configs) + ]) # (num_configs, T, N, 18) + + x_test_aug_all = np.stack([ + augment_with_inertia(data['test_x'][ci], + all_inertia[ci]['I1'], all_inertia[ci]['I2'], all_inertia[ci]['I3']) + for ci in range(num_configs) + ]) # (num_configs, T, N, 18) + + # ── Build model ─────────────────────────────────────────────────────── + model = DissipativeSO3HamNODE( + device=device, u_dim=3, init_gain=args.init_gain, inertia_dim=3).to(device) + + if args.fix_M: + _inner(model).M_net = FixedInertiaFromState().to(device).to(float_type) + print("M_net fixed to FixedInertiaFromState — reads I₁,I₂,I₃ from state.") + + print(f'Model: {get_model_parm_nums(model)} parameters') + + optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4) + + # ── Arrange data ────────────────────────────────────────────────────── + # arrange_data slices along time; all 4 config batch dims are concatenated. 
+ train_x, t_eval = arrange_data(x_aug_all, data['t'], num_points=args.num_points) + test_x, _ = arrange_data(x_test_aug_all, data['t'], num_points=args.num_points) + train_x_cat = np.concatenate(train_x, axis=1) # (num_points, B_all, 18) + test_x_cat = np.concatenate(test_x, axis=1) # (num_points, B_all, 18) + + train_x_cat = torch.tensor(train_x_cat, requires_grad=True, + dtype=float_type).to(device) + test_x_cat = torch.tensor(test_x_cat, requires_grad=True, + dtype=float_type).to(device) + t_eval = torch.tensor(t_eval, requires_grad=True, + dtype=float_type).to(device) + + # ── M_net pretraining ───────────────────────────────────────────────── + if args.pretrain_M_steps > 0 and not args.fix_M: + q_pretrain = train_x_cat.detach().reshape(-1, 18)[:, :12] # (N, 12) + pretrain_M_net( + model=model, + q_samples=q_pretrain, + n_steps=args.pretrain_M_steps, + lr=args.pretrain_M_lr, + print_every=args.pretrain_M_print_every, + ) + + split = [9, 3, 3] + + stats = { + 'train_loss': [], 'train_l2_loss': [], 'train_geo_loss': [], + 'train_power_loss': [], + 'train_V_cons_loss': [], 'train_B_cons_loss': [], 'train_D_cons_loss': [], + 'forward_time': [], 'backward_time': [], 'nfe': [], + # Eval (windowed) — recorded every eval_every steps + 'eval_step': [], + 'test_loss': [], 'test_l2_loss': [], 'test_geo_loss': [], + 'eval_M_loss': [], 'eval_V_loss': [], + 'eval_Dw_loss': [], 'eval_g_loss': [], + # Config metadata saved alongside training stats + 'num_configs': num_configs, + 'all_inertia': all_inertia, + } + + dt_train = (t_eval[1] - t_eval[0]).detach().item() + + os.makedirs(args.save_dir, exist_ok=True) + label = '-so3ham' + stats_path = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-stats.pkl') + + loss_buffer = [] + fwd_buffer = [] + bwd_buffer = [] + nfe_buffer = [] + + for step in range(args.total_steps + 1): + + # ── Training step ──────────────────────────────────────────────── + t = time.time() + train_x_hat = odeint(model, 
train_x_cat[0, :, :], t_eval, method=args.solver) + forward_time = time.time() - t + + target = strip_inertia(train_x_cat[1:, :, :]) + target_hat = strip_inertia(train_x_hat[1:, :, :]) + train_loss, train_l2_loss, train_geo_loss = rotmat_L2_geodesic_loss( + target, target_hat, split=split) + + if args.lambda_power > 0.0: + L_power = power_balance_loss(model, train_x_cat, dt_train) + else: + L_power = torch.zeros((), device=device, dtype=float_type) + + if args.lambda_V > 0.0 or args.lambda_B > 0.0 or args.lambda_D > 0.0: + L_V, L_B, L_D = consistency_subnet_losses(model, train_x_cat, dt_train) + else: + L_V = torch.zeros((), device=device, dtype=float_type) + L_B = torch.zeros((), device=device, dtype=float_type) + L_D = torch.zeros((), device=device, dtype=float_type) + + total_loss = (train_loss + + args.lambda_power * L_power + + args.lambda_V * L_V + + args.lambda_B * L_B + + args.lambda_D * L_D) + + t = time.time() + total_loss.backward() + optim.step() + optim.zero_grad() + backward_time = time.time() - t + + loss_buffer.append(torch.stack([ + total_loss.detach(), train_l2_loss.detach(), train_geo_loss.detach(), + L_power.detach(), L_V.detach(), L_B.detach(), L_D.detach() + ])) + fwd_buffer.append(forward_time) + bwd_buffer.append(backward_time) + nfe = getattr(model, 'nfe', + getattr(getattr(model, '_orig_mod', model), 'nfe', 0)) + nfe_buffer.append(nfe) + + # ── Windowed eval + diagnostics + checkpoint ───────────────────── + if step % args.eval_every == 0: + with torch.no_grad(): + test_x_hat = odeint( + model, test_x_cat[0, :, :], t_eval, method=args.solver) + tgt = strip_inertia(test_x_cat[1:, :, :]) + tgt_hat = strip_inertia(test_x_hat[1:, :, :]) + test_loss, test_l2_loss, test_geo_loss = rotmat_L2_geodesic_loss( + tgt, tgt_hat, split=split) + subnet = subnet_physics_mse_stageC(model, test_x_hat) + + if loss_buffer: + drained = torch.stack(loss_buffer, dim=0).cpu().numpy() # (N, 7) + stats['train_loss'].extend(drained[:, 0].tolist()) + 
stats['train_l2_loss'].extend(drained[:, 1].tolist()) + stats['train_geo_loss'].extend(drained[:, 2].tolist()) + stats['train_power_loss'].extend(drained[:, 3].tolist()) + stats['train_V_cons_loss'].extend(drained[:, 4].tolist()) + stats['train_B_cons_loss'].extend(drained[:, 5].tolist()) + stats['train_D_cons_loss'].extend(drained[:, 6].tolist()) + stats['forward_time'].extend(fwd_buffer) + stats['backward_time'].extend(bwd_buffer) + stats['nfe'].extend(nfe_buffer) + loss_buffer = []; fwd_buffer = []; bwd_buffer = []; nfe_buffer = [] + + test_pack = torch.stack([ + test_loss.detach(), test_l2_loss.detach(), test_geo_loss.detach() + ]).cpu().numpy() + train_total = stats['train_loss'][-1] + train_l2 = stats['train_l2_loss'][-1] + train_geo = stats['train_geo_loss'][-1] + + stats['eval_step'].append(step) + stats['test_loss'].append(float(test_pack[0])) + stats['test_l2_loss'].append(float(test_pack[1])) + stats['test_geo_loss'].append(float(test_pack[2])) + stats['eval_M_loss'].append(subnet['M_loss']) + stats['eval_V_loss'].append(subnet['V_loss']) + stats['eval_Dw_loss'].append(subnet['Dw_loss']) + stats['eval_g_loss'].append(subnet['g_loss']) + + train_pow = stats['train_power_loss'][-1] + train_LV = stats['train_V_cons_loss'][-1] + train_LB = stats['train_B_cons_loss'][-1] + train_LD = stats['train_D_cons_loss'][-1] + + print(f"[step {step:>6d}]") + print(f" train: total={train_total:.4e} " + f"L2={train_l2:.4e} geo={train_geo:.4e} " + f"power={train_pow:.4e}") + print(f" cons : L_V={train_LV:.4e} " + f"L_B={train_LB:.4e} L_D={train_LD:.4e}") + print(f" test : total={test_pack[0]:.4e} " + f"L2={test_pack[1]:.4e} geo={test_pack[2]:.4e}") + print(f" subnet MSE M={subnet['M_loss']:.3e} " + f"V={subnet['V_loss']:.3e} Dw={subnet['Dw_loss']:.3e} " + f"g={subnet['g_loss']:.3e} | nfe={nfe}") + + ckpt = (f'{args.save_dir}/{args.name}{label}' + f'-{args.solver}-{args.num_points}p-{step}.tar') + torch.save(_state_dict(model), ckpt) + to_pickle(stats, stats_path) + + # ── 
Final per-config per-trajectory eval ───────────────────────────── + t_full = torch.tensor(data['t'], requires_grad=True, + dtype=float_type).to(device) + + all_train_x_hat, all_test_x_hat = [], [] + all_train_geo, all_test_geo = [], [] + + print("\n=== Final trajectory eval (all configs) ===") + for ci in range(num_configs): + I = all_inertia[ci] + I1c, I2c, I3c = I['I1'], I['I2'], I['I3'] + + x_aug = augment_with_inertia(data['x'][ci], I1c, I2c, I3c) # (T, N, 18) + xt_aug = augment_with_inertia(data['test_x'][ci], I1c, I2c, I3c) # (T, N, 18) + + cfg_x_t = torch.tensor(x_aug, requires_grad=True, + dtype=float_type).to(device) # (T, N, 18) + cfg_xt_t = torch.tensor(xt_aug, requires_grad=True, + dtype=float_type).to(device) # (T, N, 18) + + with torch.no_grad(): + tr_hat = odeint(model, cfg_x_t[0], t_full, method=args.solver) + te_hat = odeint(model, cfg_xt_t[0], t_full, method=args.solver) + + tr_loss, tr_l2, tr_geo = traj_rotmat_L2_geodesic_loss( + strip_inertia(cfg_x_t), strip_inertia(tr_hat), split=split) + te_loss, te_l2, te_geo = traj_rotmat_L2_geodesic_loss( + strip_inertia(cfg_xt_t), strip_inertia(te_hat), split=split) + + def _mean_std(t): + v = t.detach() + return v.mean().item(), v.std().item() + + tr_geo_sum = tr_geo.sum(dim=0) + te_geo_sum = te_geo.sum(dim=0) + + print(f" cfg{ci} (I1={I1c:.4f} I2={I2c:.4f} I3={I3c:.4f}):") + print(f" train geo {_mean_std(tr_geo_sum)[0]:.4e} ± {_mean_std(tr_geo_sum)[1]:.4e}") + print(f" test geo {_mean_std(te_geo_sum)[0]:.4e} ± {_mean_std(te_geo_sum)[1]:.4e}") + + all_train_x_hat.append(tr_hat.detach().cpu().numpy()) + all_test_x_hat.append(te_hat.detach().cpu().numpy()) + all_train_geo.append(tr_geo_sum.detach().cpu().numpy()) + all_test_geo.append(te_geo_sum.detach().cpu().numpy()) + + all_train_geo_np = np.concatenate(all_train_geo) + all_test_geo_np = np.concatenate(all_test_geo) + print(f"\nAll configs — train geo {all_train_geo_np.mean():.4e} ± {all_train_geo_np.std():.4e}") + print(f"All configs — test geo 
if __name__ == "__main__":
    # Script entry point: parse the CLI, run training, then persist the final
    # model weights and the stats dictionary returned by train().
    args = get_args()
    model, stats = train(args)

    # train() rewrites args.save_dir to the per-run sub-folder, so the paths
    # below land next to the periodic checkpoints of this run.
    os.makedirs(args.save_dir, exist_ok=True)
    label = '-so3ham'

    # Final checkpoint: same naming scheme as the periodic checkpoints but
    # without the trailing step number.
    final_ckpt = (f'{args.save_dir}/{args.name}{label}'
                  f'-{args.solver}-{args.num_points}p.tar')
    torch.save(_state_dict(model), final_ckpt)

    final_stats = (f'{args.save_dir}/{args.name}{label}'
                   f'-{args.solver}-{args.num_points}p-stats.pkl')
    # Write first, report afterwards, so the message is only printed when the
    # stats file really exists on disk.
    to_pickle(stats, final_stats)
    print("Saved final stats:", final_stats)
def _unwrap(model):
    """Return ``model._orig_mod`` when that attribute exists, else ``model``.

    ``_orig_mod`` is where a torch.compile wrapper keeps the eager module, so
    this gives access to the sub-networks regardless of compilation.
    """
    return getattr(model, '_orig_mod', model)
def _unwrap(model):
    """Return the module behind a torch.compile wrapper, or *model* itself."""
    return model._orig_mod if hasattr(model, '_orig_mod') else model


@torch.no_grad()
def subnet_physics_mse_tennis(
    model,
    x_hat,
    *,
    I1: float,
    I2: float,
    I3: float,
    friction_coeff: float = 0.0,
):
    """Compute mean MSE between each subnetwork's outputs and the true physics.

    Args:
        model: DissipativeSO3HamNODE (or torch.compile wrapper) exposing
            ``M_net``, ``V_net``, ``Dw_net`` and ``g_net``.
        x_hat: odeint output, shape (T, B, D) with the initial frame INCLUDED;
            the first time index is dropped here because it is the
            ground-truth initial condition, not a prediction. Only the
            first 9 state columns (vec(R)) are used, so D is 15 for the
            Stage A state but is not hard-coded. An unbatched (T, D)
            trajectory is accepted as well.
        I1, I2, I3: principal moments of inertia [kg·m²].
        friction_coeff: scalar viscous friction (0 for Stage A torque-free data).

    Returns:
        dict with scalar floats: {'M_loss', 'V_loss', 'Dw_loss', 'g_loss'}.
    """
    inner = _unwrap(model)
    device = x_hat.device
    dtype = x_hat.dtype

    # Use only predicted timesteps (drop the ground-truth initial condition)
    pred = x_hat[1:]                                 # (T-1, B, D)
    # Collapse every leading axis into one sample axis; unlike reshape(-1, D)
    # this stays well-defined when there are zero predicted frames.
    flat = pred.flatten(0, -2)                       # (N, D)
    q = flat[:, :9]                                  # (N, 9)
    N = q.shape[0]

    # ── Subnet outputs ──────────────────────────────────────────────────
    M_pred = inner.M_net(q)                          # (N, 3, 3) — this is M⁻¹
    V_pred = inner.V_net(q).squeeze(-1)              # (N,)
    Dw_pred = inner.Dw_net(q)                        # (N, 3, 3)
    g_pred = inner.g_net(q)                          # (N, 3, 3)

    # ── Ground-truth targets ─────────────────────────────────────────────
    I3_eye = torch.eye(3, device=device, dtype=dtype)

    # M⁻¹ = diag(1/I1, 1/I2, 1/I3) — constant, independent of R
    I_inv_diag = torch.tensor([1.0 / I1, 1.0 / I2, 1.0 / I3],
                              device=device, dtype=dtype)
    M_tgt = torch.diag(I_inv_diag).unsqueeze(0).expand(N, 3, 3)

    # V = 0 (free body). The prediction is mean-centred so a constant offset
    # of the potential (a gauge choice) is not penalised; the zero target
    # needs no centring.
    V_pred_c = V_pred - V_pred.mean()
    V_tgt_c = torch.zeros_like(V_pred_c)

    # Dw = friction_coeff · I₃ (zero for Stage A, non-zero once friction enabled)
    Dw_tgt = (friction_coeff * I3_eye).unsqueeze(0).expand(N, 3, 3)

    # g = I₃ (direct body-frame torque input)
    g_tgt = I3_eye.unsqueeze(0).expand(N, 3, 3)

    # ── Mean MSE per subnet ──────────────────────────────────────────────
    M_loss = (M_pred - M_tgt).pow(2).mean().item()
    V_loss = (V_pred_c - V_tgt_c).pow(2).mean().item()
    Dw_loss = (Dw_pred - Dw_tgt).pow(2).mean().item()
    g_loss = (g_pred - g_tgt).pow(2).mean().item()

    return {'M_loss': M_loss, 'V_loss': V_loss,
            'Dw_loss': Dw_loss, 'g_loss': g_loss}