diff --git a/genesis/engine/sensors/raycaster.py b/genesis/engine/sensors/raycaster.py index 7406235688..e2d5faeb39 100644 --- a/genesis/engine/sensors/raycaster.py +++ b/genesis/engine/sensors/raycaster.py @@ -8,11 +8,13 @@ import genesis as gs import genesis.utils.array_class as array_class +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.options.sensors import ( Raycaster as RaycasterOptions, +) +from genesis.options.sensors import ( RaycastPattern, ) -from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.utils.geom import ( ti_normalize, ti_transform_by_quat, @@ -21,6 +23,7 @@ transform_by_trans_quat, ) from genesis.utils.misc import concat_with_tensor, make_tensor_field +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection from genesis.vis.rasterizer_context import RasterizerContext from .base_sensor import ( @@ -36,119 +39,6 @@ from genesis.utils.ring_buffer import TensorRingBuffer -@ti.func -def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): - """ - Moller-Trumbore ray-triangle intersection. 
- - Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise - """ - result = ti.Vector.zero(gs.ti_float, 4) - - edge1 = v1 - v0 - edge2 = v2 - v0 - - # Begin calculating determinant - also used to calculate u parameter - h = ray_dir.cross(edge2) - a = edge1.dot(h) - - # Check all conditions in sequence without early returns - valid = True - - t = gs.ti_float(0.0) - u = gs.ti_float(0.0) - v = gs.ti_float(0.0) - f = gs.ti_float(0.0) - s = ti.Vector.zero(gs.ti_float, 3) - q = ti.Vector.zero(gs.ti_float, 3) - - # If determinant is near zero, ray lies in plane of triangle - if ti.abs(a) < gs.EPS: - valid = False - - if valid: - f = 1.0 / a - s = ray_start - v0 - u = f * s.dot(h) - - if u < 0.0 or u > 1.0: - valid = False - - if valid: - q = s.cross(edge1) - v = f * ray_dir.dot(q) - - if v < 0.0 or u + v > 1.0: - valid = False - - if valid: - # At this stage we can compute t to find out where the intersection point is on the line - t = f * edge2.dot(q) - - # Ray intersection - if t <= gs.EPS: - valid = False - - if valid: - result = ti.math.vec4(t, u, v, 1.0) - - return result - - -@ti.func -def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): - """ - Fast ray-AABB intersection test. - Returns the t value of intersection, or -1.0 if no intersection. 
- """ - result = -1.0 - - # Use the slab method for ray-AABB intersection - sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) - ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) - inv_dir = 1.0 / ray_dir - - t1 = (aabb_min - ray_start) * inv_dir - t2 = (aabb_max - ray_start) * inv_dir - - tmin = ti.min(t1, t2) - tmax = ti.max(t1, t2) - - t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) - t_far = ti.min(tmax.x, tmax.y, tmax.z) - - # Check if ray intersects AABB - if t_near <= t_far: - result = t_near - - return result - - -@ti.kernel -def kernel_update_aabbs( - free_verts_state: array_class.VertsState, - fixed_verts_state: array_class.VertsState, - verts_info: array_class.VertsInfo, - faces_info: array_class.FacesInfo, - aabb_state: ti.template(), -): - for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): - aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) - aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) - - for i in ti.static(range(3)): - i_v = faces_info.verts_idx[i_f][i] - i_fv = verts_info.verts_state_idx[i_v] - if verts_info.is_fixed[i_v]: - pos_v = fixed_verts_state.pos[i_fv] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - else: - pos_v = free_verts_state.pos[i_fv, i_b] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - - @ti.kernel def kernel_cast_rays( fixed_verts_state: array_class.VertsState, @@ -195,8 +85,7 @@ def kernel_cast_rays( ray_dir_local = ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], ray_directions[i_p, 2]) ray_direction_world = ti_normalize(ti_transform_by_quat(ray_dir_local, link_quat), gs.EPS) - # --- 2. BVH Traversal --- - # FIXME: this duplicates the logic in LBVH.query() which also does traversal + # --- 2. 
BVH Traversal for ray intersection --- max_range = max_ranges[i_s] hit_face = -1 diff --git a/genesis/ext/pyrender/interaction/__init__.py b/genesis/ext/pyrender/interaction/__init__.py new file mode 100644 index 0000000000..5dbb3a8ebd --- /dev/null +++ b/genesis/ext/pyrender/interaction/__init__.py @@ -0,0 +1,12 @@ +from .base_interaction import ( + EVENT_HANDLE_STATE, + EVENT_HANDLED, + VIEWER_PLUGIN_MAP, + BaseViewerInteraction, + register_viewer_plugin, +) +from .plugins.mesh_selector import MeshPointSelectorPlugin +from .plugins.mouse_interaction import MouseSpringViewerPlugin +from .plugins.viewer_controls import ViewerDefaultControls +from .ray import Plane, Ray, RayHit +from .vec3 import Color, Pose, Quat, Vec3 diff --git a/genesis/ext/pyrender/interaction/base_interaction.py b/genesis/ext/pyrender/interaction/base_interaction.py new file mode 100644 index 0000000000..9e0bf7badf --- /dev/null +++ b/genesis/ext/pyrender/interaction/base_interaction.py @@ -0,0 +1,127 @@ +from typing import TYPE_CHECKING, Literal, Type + +import numpy as np + +from .ray import Ray +from .vec3 import Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.options.viewer_interactions import ViewerInteraction as ViewerPluginOptions + + +EVENT_HANDLE_STATE = Literal[True] | None +EVENT_HANDLED: Literal[True] = True + +# Global map from options class to viewer plugin class +VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["BaseViewerInteraction"]] = {} + + +def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): + """ + Decorator to register a viewer plugin class with its corresponding options class. + + Parameters + ---------- + options_cls : Type[ViewerPluginOptions] + The options class that configures this viewer plugin. + + Returns + ------- + Callable + The decorator function that registers the plugin class. 
+ + Example + ------- + @register_viewer_plugin(ViewerInteractionOptions) + class ViewerInteraction(ViewerInteractionBase): + ... + """ + def _impl(plugin_cls: Type["BaseViewerInteraction"]): + VIEWER_PLUGIN_MAP[options_cls] = plugin_cls + return plugin_cls + return _impl + +# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow + +class BaseViewerInteraction(): + """ + Base class for handling pyglet.window.Window events. + """ + + def __init__( + self, + viewer, + options: "ViewerPluginOptions", + camera: "Node", + scene: "Scene", + viewport_size: tuple[int, int], + ): + self.viewer = viewer + self.options: "ViewerPluginOptions" = options + self.camera: 'Node' = camera + self.scene: 'Scene' = scene + self.viewport_size: tuple[int, int] = viewport_size + + self.camera_yfov: float = camera.camera.yfov + self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: + self.viewport_size = (width, height) + self.tan_half_fov = np.tan(0.5 * self.camera_yfov) + + def update_on_sim_step(self) -> None: + pass + + def on_draw(self) -> None: + pass + + def on_close(self) -> None: + pass + + def _screen_position_to_ray(self, x: float, y: float) -> Ray: + # convert screen position to ray + x = x - 0.5 * self.viewport_size[0] + y = y - 0.5 * self.viewport_size[1] + x = 
2.0 * x / self.viewport_size[1] * self.tan_half_fov + y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov + + # Note: ignoring pixel aspect ratio + + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + right = Vec3.from_array(mtx[:3, 0]) + up = Vec3.from_array(mtx[:3, 1]) + + direction = forward + right * x + up * y + return Ray(position, direction) + + def _get_camera_forward(self) -> Vec3: + mtx = self.camera.matrix + return Vec3.from_array(-mtx[:3, 2]) + + def _get_camera_ray(self) -> Ray: + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + return Ray(position, forward) \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py new file mode 100644 index 0000000000..e461a8e6b6 --- /dev/null +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -0,0 +1,27 @@ +from pyglet.window.key import symbol_string + + +class Keybindings: + def __init__(self, map: dict[str, int] = {}, **kwargs: dict[str, int]): + self._map: dict[str, int] = {**map, **kwargs} + + def __getattr__(self, name: str) -> int: + if name in self._map: + return self._map[name] + raise AttributeError(f"Action '{name}' not found in keybindings.") + + def as_instruction_texts(self, padding, exclude: tuple[str]) -> list[str]: + width = 4 + padding + return [ + f"{'[' + symbol_string(self._map[action]).lower():>{width}}]: " + + action.replace('_', ' ') for action in self._map.keys() if action not in exclude + ] + + def extend(self, mapping: dict[str, int], replace_only: bool = False) -> None: + current_keys = self._map.keys() + for action, key in mapping.items(): + if replace_only and action not in self._map: + raise KeyError(f"Action '{action}' not found. 
Available actions: {list(self._map.keys())}") + if key in current_keys: + raise ValueError(f"Key '{symbol_string(key)}' is already assigned to another action.") + self._map[action] = key \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/mouse_spring.py b/genesis/ext/pyrender/interaction/mouse_spring.py deleted file mode 100644 index 7aace796a0..0000000000 --- a/genesis/ext/pyrender/interaction/mouse_spring.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -from .ray import Plane, Ray, RayHit -from .vec3 import Pose, Quat, Vec3, Color - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity.rigid_link import RigidLink - - -MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 -MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 - - -class MouseSpring: - def __init__(self) -> None: - self.held_link: "RigidLink | None" = None - self.held_point_in_local: Vec3 | None = None - self.prev_control_point: Vec3 | None = None - - def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: - # for now, we just pick the first geometry - self.held_link = picked_link - pose: Pose = Pose.from_link(self.held_link) - self.held_point_in_local = pose.inverse_transform_point(control_point) - self.prev_control_point = control_point - - def detach(self) -> None: - self.held_link = None - - def apply_force(self, control_point: Vec3, delta_time: float) -> None: - # note when threaded: apply_force is called before attach! 
- # note2: that was before we added a lock to ViewerInteraction; this migth be fixed now - if not self.held_link: - return - - self.prev_control_point = control_point - - # do simple force on COM only: - link: "RigidLink" = self.held_link - lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) - ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) - link_pose: Pose = Pose.from_link(link) - held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) - - # note: we should assert earlier that link inertial_pos/quat are not None - # todo: verify inertial_pos/quat are stored in local frame - link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) - world_T_principal: Pose = link_pose * link_T_principal - - arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia - arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia - - pos_err_v: Vec3 = control_point - held_point_in_world - inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0) - inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) - - inv_dt: float = 1.0 / delta_time - tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR - damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR - - total_impulse: Vec3 = Vec3.zero() - total_torque_impulse: Vec3 = Vec3.zero() - - for i in range(3*4): - body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) - vel_err_v: Vec3 = Vec3.zero() - body_point_vel - - dir: Vec3 = Vec3.zero() - dir.v[i % 3] = 1.0 - pos_err: float = dir.dot(pos_err_v) - vel_err: float = dir.dot(vel_err_v) - error: float = tau * pos_err * inv_dt + damp * vel_err - - arm_x_dir: Vec3 = arm_in_world.cross(dir) - virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) - impulse: float = error * virtual_mass - - lin_vel += impulse * 
class ViewerControls(BaseViewerInteraction):
    """
    Keyboard shortcut handler for the Genesis viewer.

    Key presses toggle viewer and render flags, debug frames, video recording and the on-screen
    instruction overlay.
    """

    def __init__(
        self,
        viewer,
        options=None,
        camera: "Node" = None,
        scene: "Scene" = None,
        viewport_size: tuple[int, int] = None,
    ):
        super().__init__(viewer, options, camera, scene, viewport_size)

        # Instruction overlay: index 0 is the collapsed hint, index 1 the full shortcut list.
        self._display_instr = False
        collapsed_help = ["> [i]: show keyboard instructions"]
        expanded_help = [
            "< [i]: hide keyboard instructions",
            " [r]: record video",
            " [s]: save image",
            " [z]: reset camera",
            " [a]: camera rotation",
            " [h]: shadow",
            " [f]: face normal",
            " [v]: vertex normal",
            " [w]: world frame",
            " [l]: link frame",
            " [d]: wireframe",
            " [c]: camera & frustrum",
            " [F11]: full-screen mode",
        ]
        self._instr_texts = [collapsed_help, expanded_help]

    @staticmethod
    def _toggle_flag(flags, name) -> bool:
        """Invert the boolean stored under ``flags[name]`` and return the value now stored there."""
        flags[name] = not flags[name]
        return flags[name]

    @override
    def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE:
        """Apply the viewer shortcut bound to ``symbol``, if any. Never marks the event as handled."""
        viewer = self.viewer
        if viewer is None:
            return None

        keys = pyglet.window.key

        # Start from "no message"; a branch below may set one, which then resets the fade timer.
        viewer._message_text = None

        if symbol == keys.A:
            # Automatic camera rotation
            rotating = self._toggle_flag(viewer.viewer_flags, "rotate")
            viewer._message_text = "Rotation On" if rotating else "Rotation Off"

        elif symbol == keys.F11:
            # Fullscreen
            self._toggle_flag(viewer.viewer_flags, "fullscreen")
            viewer.set_fullscreen(viewer.viewer_flags["fullscreen"])
            viewer.activate()
            viewer._message_text = "Fullscreen On" if viewer.viewer_flags["fullscreen"] else "Fullscreen Off"

        elif symbol == keys.H:
            # Shadows
            shadows = self._toggle_flag(viewer.render_flags, "shadows")
            viewer._message_text = "Shadows On" if shadows else "Shadows Off"

        elif symbol == keys.W:
            # World frame
            if viewer.gs_context.world_frame_shown:
                viewer.gs_context.off_world_frame()
                viewer._message_text = "World Frame Off"
            else:
                viewer.gs_context.on_world_frame()
                viewer._message_text = "World Frame On"

        elif symbol == keys.L:
            # Link frame
            if viewer.gs_context.link_frame_shown:
                viewer.gs_context.off_link_frame()
                viewer._message_text = "Link Frame Off"
            else:
                viewer.gs_context.on_link_frame()
                viewer._message_text = "Link Frame On"

        elif symbol == keys.C:
            # Camera frustum
            if viewer.gs_context.camera_frustum_shown:
                viewer.gs_context.off_camera_frustum()
                viewer._message_text = "Camera Frustrum Off"
            else:
                viewer.gs_context.on_camera_frustum()
                viewer._message_text = "Camera Frustrum On"

        elif symbol == keys.F:
            # Face normals
            face_normals = self._toggle_flag(viewer.render_flags, "face_normals")
            viewer._message_text = "Face Normals On" if face_normals else "Face Normals Off"

        elif symbol == keys.V:
            # Vertex normals
            vertex_normals = self._toggle_flag(viewer.render_flags, "vertex_normals")
            viewer._message_text = "Vert Normals On" if vertex_normals else "Vert Normals Off"

        elif symbol == keys.R:
            # Start or stop video recording
            if viewer.viewer_flags["record"]:
                viewer.save_video()
                viewer.set_caption(viewer.viewer_flags["window_title"])
            else:
                # Importing moviepy is very slow and rarely needed, so it is deferred until recording starts.
                from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

                video_path = os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4")
                viewer._video_recorder = FFMPEG_VideoWriter(
                    filename=video_path,
                    fps=viewer.viewer_flags["refresh_rate"],
                    size=viewer.viewport_size,
                )
                viewer.set_caption("{} (RECORDING)".format(viewer.viewer_flags["window_title"]))
            viewer.viewer_flags["record"] = not viewer.viewer_flags["record"]

        elif symbol == keys.S:
            # Save the current frame as an image
            viewer._save_image()

        elif symbol == keys.D:
            # Cycle wireframe modes: flip -> all wireframe -> all solid -> default -> flip
            render_flags = viewer.render_flags
            if render_flags["flip_wireframe"]:
                next_mode = (False, True, False, "All Wireframe")
            elif render_flags["all_wireframe"]:
                next_mode = (False, False, True, "All Solid")
            elif render_flags["all_solid"]:
                next_mode = (False, False, False, "Default Wireframe")
            else:
                next_mode = (True, False, False, "Flip Wireframe")
            (
                render_flags["flip_wireframe"],
                render_flags["all_wireframe"],
                render_flags["all_solid"],
                viewer._message_text,
            ) = next_mode

        elif symbol == keys.Z:
            # Reset the camera viewpoint
            viewer._reset_view()

        elif symbol == keys.I:
            # Show or hide the instruction overlay
            self._display_instr = not self._display_instr

        elif symbol == keys.P:
            # Reload the shader program
            viewer._renderer.reload_program()

        if viewer._message_text is not None:
            viewer._message_opac = 1.0 + viewer._ticks_till_fade

        return None

    @override
    def on_draw(self):
        """Render keyboard instructions."""
        viewer = self.viewer
        if viewer is None:
            return

        lines = self._instr_texts[1] if self._display_instr else self._instr_texts[0]
        viewer._renderer.render_texts(
            lines,
            TEXT_PADDING,
            viewer.viewport_size[1] - TEXT_PADDING,
            font_pt=26,
            color=np.array([1.0, 1.0, 1.0, 0.85]),
        )
@register_viewer_plugin(MeshPointSelectorPluginOptions)
class MeshPointSelectorPlugin(ViewerDefaultControls):
    """
    Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes.

    Selected points are stored in local coordinates relative to their link's frame. While the viewer is
    open, a hover preview and all selected points are drawn as debug objects. When the viewer closes, the
    selection is written as CSV to ``options.output_file``.
    """

    def __init__(
        self,
        viewer,
        options: MeshPointSelectorPluginOptions,
        camera: "Node",
        scene: "Scene",
        viewport_size: tuple[int, int],
    ) -> None:
        super().__init__(viewer, options, camera, scene, viewport_size)
        # Last known cursor position; starts at the viewport centre until the mouse first moves.
        self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2)

        # List of selected points with link, local position, and local normal
        self.selected_points: list[SelectedPoint] = []

        self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene)

    def _snap_to_grid(self, position: Vec3) -> Vec3:
        """
        Snap a position to the grid based on grid_snap settings.

        Parameters
        ----------
        position : Vec3
            The position to snap.

        Returns
        -------
        Vec3
            The snapped position. An axis whose snap step is zero or negative is returned unchanged.
        """
        snap_x, snap_y, snap_z = self.options.grid_snap

        # Only a strictly positive step defines a grid. A step of 0 used to fall into the division below and
        # fail with a division by zero, so it is now treated as "snapping disabled", like negative steps.
        x = round(position.x / snap_x) * snap_x if snap_x > 0 else position.x
        y = round(position.y / snap_y) * snap_y if snap_y > 0 else position.y
        z = round(position.z / snap_z) * snap_z if snap_z > 0 else position.z

        return Vec3.from_xyz(x, y, z)

    @override
    def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
        """Remember the cursor position so that ``on_draw`` can render the hover preview."""
        super().on_mouse_motion(x, y, dx, dy)
        self.prev_mouse_pos = (x, y)
        return None

    @override
    def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
        """
        On a left click that hits a geometry, record the hit point in the hit link's local frame.

        Returns ``EVENT_HANDLED`` when a point was recorded, otherwise ``None``.
        """
        super().on_mouse_press(x, y, button, modifiers)
        if button == 1:  # left mouse button
            ray = self._screen_position_to_ray(x, y)
            ray_hit = self.raycaster.cast_ray(ray.origin.v, ray.direction.v)

            if ray_hit.is_hit and ray_hit.geom:
                link = ray_hit.geom.link
                world_pos = ray_hit.position
                world_normal = ray_hit.normal

                pose: Pose = Pose.from_link(link)
                local_pos = pose.inverse_transform_point(world_pos)
                local_normal = pose.inverse_transform_direction(world_normal)

                # Apply grid snapping to local position
                local_pos = self._snap_to_grid(local_pos)

                selected_point = SelectedPoint(
                    link=link,
                    local_position=local_pos,
                    local_normal=local_normal,
                )
                self.selected_points.append(selected_point)

                return EVENT_HANDLED
        return None

    @override
    def update_on_sim_step(self) -> None:
        """Refresh the raycaster's BVH after each simulation step."""
        self.raycaster.update_bvh()

    @override
    def on_draw(self) -> None:
        """Redraw the hover preview under the cursor and a sphere for every selected point."""
        super().on_draw()
        if self.scene._visualizer is not None and self.scene._visualizer.is_built:
            # Debug objects are rebuilt from scratch every frame.
            self.scene.clear_debug_objects()
            mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos)

            closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v)
            if closest_hit.is_hit:
                # NOTE(review): the preview is snapped in world coordinates, whereas a click snaps in the
                # link's local frame - confirm whether this difference is intended.
                snap_pos = self._snap_to_grid(closest_hit.position)
                # Draw hover preview
                self.scene.draw_debug_sphere(
                    snap_pos.v,
                    self.options.sphere_radius,
                    self.options.hover_color,
                )
                self.scene.draw_debug_arrow(
                    snap_pos.v,
                    closest_hit.normal.v * 0.1,
                    self.options.sphere_radius / 2,
                    self.options.hover_color,
                )

            if self.selected_points:
                # Points are stored in link-local coordinates, so recompute their world position each frame.
                world_positions = []
                for point in self.selected_points:
                    pose = Pose.from_link(point.link)
                    current_world_pos = pose.transform_point(point.local_position)
                    world_positions.append(current_world_pos.v)

                if len(world_positions) == 1:
                    self.scene.draw_debug_sphere(
                        world_positions[0],
                        self.options.sphere_radius,
                        self.options.sphere_color,
                    )
                else:
                    import numpy as np

                    positions_array = np.array(world_positions)
                    self.scene.draw_debug_spheres(
                        positions_array, self.options.sphere_radius, self.options.sphere_color
                    )

    @override
    def on_close(self) -> None:
        """Write the selected points to ``options.output_file`` as CSV; log instead of raising on failure."""
        super().on_close()

        if not self.selected_points:
            # Use the logger here as well, so all messages of this plugin go through the same channel.
            gs.logger.info("[MeshPointSelectorPlugin] No points selected.")
            return

        output_file = self.options.output_file
        try:
            with open(output_file, 'w', newline='') as csvfile:
                writer = csv.writer(csvfile)

                writer.writerow([
                    'point_idx',
                    'link_idx',
                    'local_pos_x',
                    'local_pos_y',
                    'local_pos_z',
                    'local_normal_x',
                    'local_normal_y',
                    'local_normal_z',
                ])

                # point_idx is 1-based
                for i, point in enumerate(self.selected_points, 1):
                    writer.writerow([
                        i,
                        point.link.idx,
                        point.local_position.x,
                        point.local_position.y,
                        point.local_position.z,
                        point.local_normal.x,
                        point.local_normal.y,
                        point.local_normal.z,
                    ])

            gs.logger.info(
                f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'"
            )

        except Exception as e:
            # Deliberately best-effort: failing to save the selection must not break closing the viewer.
            gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}")
class MouseSpring:
    """
    Spring-like constraint that pulls one point of a rigid link towards a control point.

    A link is attached together with the grabbed point (stored in the link's local frame). Each call to
    ``apply_force`` computes a force and torque that drive that point towards the given control point and
    applies them to the link through its solver.
    """

    def __init__(self) -> None:
        # Link currently being dragged, or None when nothing is attached
        self.held_link: "RigidLink | None" = None
        # Grabbed point expressed in the held link's local frame
        self.held_point_in_local: Vec3 | None = None
        # Control point recorded by the most recent attach() / apply_force() call
        self.prev_control_point: Vec3 | None = None

    def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None:
        """Start holding ``picked_link`` at ``control_point``, stored in the link's local frame."""
        # for now, we just pick the first geometry
        self.held_link = picked_link
        pose: Pose = Pose.from_link(self.held_link)
        self.held_point_in_local = pose.inverse_transform_point(control_point)
        self.prev_control_point = control_point

    def detach(self) -> None:
        """Release the held link; ``held_point_in_local`` and ``prev_control_point`` are left as they are."""
        self.held_link = None

    def apply_force(self, control_point: Vec3, delta_time: float) -> None:
        """
        Apply an external force and torque that pull the held point towards ``control_point``.

        Does nothing when no link is attached.

        Parameters
        ----------
        control_point : Vec3
            Target position for the held point. It is compared against the held point transformed by the
            link pose, i.e. presumably given in world coordinates - TODO confirm.
        delta_time : float
            Time step used to convert position error into velocity and impulses into forces. It is
            divided by, so it must be non-zero.
        """
        # note when threaded: apply_force is called before attach!
        # note2: that was before we added a lock to ViewerInteraction; this might be fixed now
        if not self.held_link:
            return

        self.prev_control_point = control_point

        # do simple force on COM only:
        link: "RigidLink" = self.held_link
        lin_vel: Vec3 = Vec3.from_tensor(link.get_vel())
        ang_vel: Vec3 = Vec3.from_tensor(link.get_ang())
        link_pose: Pose = Pose.from_link(link)
        held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local)

        # note: we should assert earlier that link inertial_pos/quat are not None
        # todo: verify inertial_pos/quat are stored in local frame
        link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat))
        world_T_principal: Pose = link_pose * link_T_principal

        arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local)  # for non-spherical inertia
        arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal  # for spherical inertia

        pos_err_v: Vec3 = control_point - held_point_in_world
        # A non-positive mass or inertia yields an inverse of 0, so that term contributes no motion below.
        inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0)
        # Only the first diagonal entry of the inertia matrix is used (inertia treated as a single scalar).
        inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0)

        inv_dt: float = 1.0 / delta_time
        tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR
        damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR

        total_impulse: Vec3 = Vec3.zero()
        total_torque_impulse: Vec3 = Vec3.zero()

        # 12 iterations = 4 sweeps over the 3 coordinate axes (axis index is i % 3). Each iteration corrects
        # the error along one axis and updates the local velocity estimates, so later iterations see the
        # effect of earlier corrections.
        for i in range(3*4):
            # Velocity of the held point from the current linear and angular velocity estimates
            body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world)
            # Target velocity is zero, so the velocity error is minus the point velocity
            vel_err_v: Vec3 = Vec3.zero() - body_point_vel

            # Unit vector along the axis handled in this iteration
            dir: Vec3 = Vec3.zero()
            dir.v[i % 3] = 1.0
            pos_err: float = dir.dot(pos_err_v)
            vel_err: float = dir.dot(vel_err_v)
            # Desired velocity change along the axis: position error per time step plus velocity error
            error: float = tau * pos_err * inv_dt + damp * vel_err

            arm_x_dir: Vec3 = arm_in_world.cross(dir)
            # The 1e-24 term keeps the denominator non-zero when both inverse mass and inverse inertia are 0.
            virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24)
            impulse: float = error * virtual_mass

            lin_vel += impulse * inv_mass * dir
            ang_vel += impulse * inv_spherical_inertia * arm_x_dir
            total_impulse.v[i % 3] += impulse
            total_torque_impulse += impulse * arm_x_dir

        # Apply the new force
        # Accumulated impulses are divided by the time step to obtain a force and a torque.
        total_force = total_impulse * inv_dt
        total_torque = total_torque_impulse * inv_dt
        # unsqueeze(0) adds a leading dimension - presumably one row per link index passed below; TODO confirm.
        force_tensor: torch.Tensor = total_force.as_tensor().unsqueeze(0)
        torque_tensor: torch.Tensor = total_torque.as_tensor().unsqueeze(0)
        link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False)
        link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False)

    @property
    def is_attached(self) -> bool:
        """True while a link is being held."""
        return self.held_link is not None
""" - def __init__(self, - camera: 'Node', - scene: 'Scene', + def __init__( + self, + viewer, + options: MouseSpringViewerPluginOptions, + camera: "Node", + scene: "Scene", viewport_size: tuple[int, int], - camera_yfov: float, - log_events: bool = False, - camera_fov: float = 60.0, ) -> None: - super().__init__(log_events) - self.camera: 'Node' = camera - self.scene: 'Scene' = scene - self.viewport_size: tuple[int, int] = viewport_size - self.camera_yfov: float = camera_yfov - - self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + super().__init__(viewer, options, camera, scene, viewport_size) self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) self.picked_link: RigidLink | None = None @@ -49,6 +130,8 @@ def __init__(self, self.mouse_spring: MouseSpring = MouseSpring() self.lock = threading_Lock() + self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene) + @override def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: super().on_mouse_motion(x, y, dx, dy) @@ -67,13 +150,14 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: super().on_mouse_press(x, y, button, modifiers) if button == 1: # left mouse button - ray_hit = self.raycast_against_entities(self.screen_position_to_ray(x, y)) + + ray_hit = self.raycaster.cast_ray(self._screen_position_to_ray(x, y).origin.v, self._screen_position_to_ray(x, y).direction.v) with self.lock: if ray_hit.geom: self.picked_link = ray_hit.geom.link assert self.picked_link is not None - temp_fwd = self.get_camera_forward() + temp_fwd = self._get_camera_forward() temp_back = -temp_fwd self.mouse_drag_plane = Plane(temp_back, ray_hit.position) @@ -96,17 +180,13 @@ def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT self.mouse_spring.detach() - @override - def on_resize(self, width: int, height: int) -> 
EVENT_HANDLE_STATE: - super().on_resize(width, height) - self.viewport_size = (width, height) - self.tan_half_fov = np.tan(0.5 * self.camera_yfov) - @override def update_on_sim_step(self) -> None: + self.raycaster.update_bvh() + with self.lock: if self.picked_link: - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray) assert ray_hit.is_hit if ray_hit.is_hit: @@ -119,7 +199,7 @@ def update_on_sim_step(self) -> None: # apply force self.mouse_spring.apply_force(new_mouse_3d_pos, self.scene.sim.dt) else: - #apply displacement + # apply displacement pos = Vec3.from_tensor(self.picked_link.entity.get_pos()) pos += delta_3d_pos self.picked_link.entity.set_pos(pos.as_tensor()) @@ -129,11 +209,8 @@ def on_draw(self) -> None: super().on_draw() if self.scene._visualizer is not None and self.scene._visualizer.is_built: self.scene.clear_debug_objects() - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) - - closest_hit = self.raycast_against_entities(mouse_ray) - if not closest_hit.is_hit: - closest_hit = self._raycast_against_ground_plane(mouse_ray) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) + closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v) with self.lock: if self.picked_link: @@ -157,76 +234,6 @@ def on_draw(self) -> None: self._draw_entity_unrotated_obb(closest_hit.geom) - def screen_position_to_ray(self, x: float, y: float) -> Ray: - # convert screen position to ray - if True: - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov - y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov - else: - # alternative way - projection_matrix = self.camera.camera.get_projection_matrix(*self.viewport_size) - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 
2.0 * x / self.viewport_size[0] / projection_matrix[0, 0] - y = 2.0 * y / self.viewport_size[1] / projection_matrix[1, 1] - - # Note: ignoring pixel aspect ratio - - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - right = Vec3.from_array(mtx[:3, 0]) - up = Vec3.from_array(mtx[:3, 1]) - - direction = forward + right * x + up * y - return Ray(position, direction) - - def get_camera_forward(self) -> Vec3: - mtx = self.camera.matrix - return Vec3.from_array(-mtx[:3, 2]) - - def get_camera_ray(self) -> Ray: - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - return Ray(position, forward) - - def _raycast_against_ground_plane(self, ray: Ray) -> RayHit: - ground_plane = Plane(Vec3.from_xyz(0, 0, 1), Vec3.zero()) - return ground_plane.raycast(ray) - - def raycast_against_entity_obb(self, entity: "RigidEntity", ray: Ray) -> RayHit: - if isinstance(entity.morph, gs.morphs.Box): - obb: OBB = self._get_box_obb(entity) - ray_hit = obb.raycast(ray) - if ray_hit.is_hit: - ray_hit.geom = entity.geoms[0] - return ray_hit - elif isinstance(entity.morph, gs.morphs.Plane): - # ignore plane - return RayHit.no_hit() - else: - closest_hit = RayHit.no_hit() - for link in entity.links: - if not link.is_fixed: - for geom in link.geoms: - obb: OBB = self._get_geom_placeholder_obb(geom) - ray_hit = obb.raycast(ray) - if ray_hit.distance < closest_hit.distance: - ray_hit.geom = geom - closest_hit = ray_hit - return closest_hit - - def raycast_against_entities(self, ray: Ray) -> RayHit: - closest_hit = RayHit.no_hit() - for entity in self.scene.sim.rigid_solver.entities: - rigid_entity: "RigidEntity" = cast("RigidEntity", entity) - ray_hit = self.raycast_against_entity_obb(rigid_entity, ray) - if ray_hit.distance < closest_hit.distance: - closest_hit = ray_hit - return closest_hit def _get_box_obb(self, box_entity: "RigidEntity") -> OBB: box: gs.morphs.Box = box_entity.morph 
diff --git a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py new file mode 100644 index 0000000000..7857d122ce --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py @@ -0,0 +1,219 @@ +import os +from typing import TYPE_CHECKING + +import numpy as np +import pyglet +from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions +from typing_extensions import override + +import genesis as gs + +from ...constants import TEXT_PADDING +from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction, register_viewer_plugin +from ..keybindings import Keybindings + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + +@register_viewer_plugin(ViewerDefaultControlsOptions) +class ViewerDefaultControls(BaseViewerInteraction): + """ + Default keyboard controls for the Genesis viewer. + + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. 
+ """ + + def __init__( + self, + viewer, + options=None, + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + self.keybindings: Keybindings = Keybindings( + toggle_keyboard_instructions=pyglet.window.key.I, + record_video=pyglet.window.key.R, + save_image=pyglet.window.key.S, + reset_camera=pyglet.window.key.Z, + camera_rotation=pyglet.window.key.A, + shadow=pyglet.window.key.H, + face_normals=pyglet.window.key.F, + vertex_normals=pyglet.window.key.V, + world_frame=pyglet.window.key.W, + link_frame=pyglet.window.key.L, + wireframe=pyglet.window.key.D, + camera_frustum=pyglet.window.key.C, + fullscreen_mode=pyglet.window.key.F11, + ) + if options and options.keybindings: + self.keybindings.apply_override_mapping(options.keybindings) + + self._display_instr = False + self._instr_texts = ( + ["> [i]: show keyboard instructions"], + ["< [i]: hide keyboard instructions"] + self.keybindings.as_instruction_texts( + padding=3, exclude=("toggle_keyboard_instructions")), + ) + + @override + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + if self.viewer is None: + return None + + # A causes the frame to rotate + self.viewer._message_text = None + if symbol == self.keybindings.camera_rotation: + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.viewer._message_text = "Rotation On" + else: + self.viewer._message_text = "Rotation Off" + + # F11 toggles fullscreen + elif symbol == self.keybindings.fullscreen_mode: + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.viewer._message_text = "Fullscreen On" + else: + self.viewer._message_text = "Fullscreen Off" + + # H toggles shadows + elif 
symbol == self.keybindings.shadow: + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.viewer._message_text = "Shadows On" + else: + self.viewer._message_text = "Shadows Off" + + # W toggles world frame + elif symbol == self.keybindings.world_frame: + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.viewer._message_text = "World Frame On" + else: + self.viewer.gs_context.off_world_frame() + self.viewer._message_text = "World Frame Off" + + # L toggles link frame + elif symbol == self.keybindings.link_frame: + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.viewer._message_text = "Link Frame On" + else: + self.viewer.gs_context.off_link_frame() + self.viewer._message_text = "Link Frame Off" + + # C toggles camera frustum + elif symbol == self.keybindings.camera_frustum: + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.viewer._message_text = "Camera Frustrum On" + else: + self.viewer.gs_context.off_camera_frustum() + self.viewer._message_text = "Camera Frustrum Off" + + # F toggles face normals + elif symbol == self.keybindings.face_normals: + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.viewer._message_text = "Face Normals On" + else: + self.viewer._message_text = "Face Normals Off" + + # V toggles vertex normals + elif symbol == self.keybindings.vertex_normals: + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.viewer._message_text = "Vert Normals On" + else: + self.viewer._message_text = "Vert Normals Off" + + # R starts recording frames + elif symbol == self.keybindings.record_video: + if self.viewer.viewer_flags["record"]: + 
self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + # S saves the current frame as an image + elif symbol == self.keybindings.save_image: + self.viewer._save_image() + + # D toggles through wireframe modes + elif symbol == self.keybindings.wireframe: + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "All Wireframe" + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.viewer._message_text = "All Solid" + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Default Wireframe" + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Flip Wireframe" + + # Z resets the camera viewpoint + elif symbol == self.keybindings.reset_camera: + self.viewer._reset_view() + + # I toggles instruction display + elif symbol == self.keybindings.toggle_keyboard_instructions: + 
self._display_instr = not self._display_instr + + # P reloads shader program + elif symbol == self.keybindings.reload_shader: + self.viewer._renderer.reload_program() + + if self.viewer._message_text is not None: + self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade + + return None + + @override + def on_draw(self): + """Render keyboard instructions.""" + if self.viewer is None: + return + + if self._display_instr: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) diff --git a/genesis/ext/pyrender/interaction/raycaster.py b/genesis/ext/pyrender/interaction/raycaster.py new file mode 100644 index 0000000000..49f987a2cb --- /dev/null +++ b/genesis/ext/pyrender/interaction/raycaster.py @@ -0,0 +1,281 @@ +from typing import TYPE_CHECKING + +import gstaichi as ti +import numpy as np +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection + +import genesis as gs + +from .ray import RayHit +from .vec3 import Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + + +# Constant to indicate no hit occurred +NO_HIT_DISTANCE = -1.0 + + + +@ti.kernel +def kernel_cast_single_ray_for_viewer( + fixed_verts_state: ti.template(), + free_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + bvh_nodes: ti.template(), + bvh_morton_codes: ti.template(), + ray_start: ti.types.ndarray(ndim=1), # [3] + ray_direction: ti.types.ndarray(ndim=1), # [3] + max_range: ti.f32, + envs_idx: ti.types.ndarray(ndim=1), # [n_envs] + result: ti.types.ndarray(ndim=1), # [9]: [distance, geom_idx, hit_x, hit_y, hit_z, normal_x, 
normal_y, normal_z, env_idx] +): + """ + Taichi kernel for casting a single ray for viewer interaction. + + This loops over all environments in envs_idx and returns the closest hit. + + Returns: + result[0]: distance to hit point (NO_HIT_DISTANCE if no hit) + result[1]: geom_idx of hit geometry + result[2]: hit_point x coordinate + result[3]: hit_point y coordinate + result[4]: hit_point z coordinate + result[5]: normal x coordinate + result[6]: normal y coordinate + result[7]: normal z coordinate + result[8]: env_idx of hit environment + """ + n_triangles = faces_info.verts_idx.shape[0] + + # Setup ray + ray_start_world = ti.math.vec3(ray_start[0], ray_start[1], ray_start[2]) + ray_direction_world = ti.math.vec3(ray_direction[0], ray_direction[1], ray_direction[2]) + + # Initialize result with no hit + result[0] = -1.0 # NO_HIT_DISTANCE + result[1] = -1.0 # no geom + result[2] = 0.0 # hit_point x + result[3] = 0.0 # hit_point y + result[4] = 0.0 # hit_point z + result[5] = 0.0 # normal x + result[6] = 0.0 # normal y + result[7] = 0.0 # normal z + result[8] = -1.0 # no env + + global_closest_distance = max_range + global_hit_face = -1 + global_hit_env_idx = -1 + global_hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Loop over all environments in envs_idx + for i_b in range(envs_idx.shape[0]): + rendered_env_idx = ti.cast(envs_idx[i_b], ti.i32) + + hit_face = -1 + closest_distance = global_closest_distance + hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Stack for non-recursive BVH traversal + node_stack = ti.Vector.zero(ti.i32, STACK_SIZE) + node_stack[0] = 0 # Start at root node + stack_idx = 1 + + while stack_idx > 0: + stack_idx -= 1 + node_idx = node_stack[stack_idx] + + node = bvh_nodes[i_b, node_idx] + + # Check if ray hits the node's bounding box + aabb_t = ray_aabb_intersection(ray_start_world, ray_direction_world, node.bound.min, node.bound.max) + + if aabb_t >= 0.0 and aabb_t < closest_distance: + if node.left == -1: # Leaf node + # Get original 
triangle/face index + sorted_leaf_idx = node_idx - (n_triangles - 1) + i_f = ti.cast(bvh_morton_codes[0, sorted_leaf_idx][1], ti.i32) + + # Get triangle vertices + tri_vertices = ti.Matrix.zero(gs.ti_float, 3, 3) + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + tri_vertices[:, i] = fixed_verts_state.pos[i_fv] + else: + tri_vertices[:, i] = free_verts_state.pos[i_fv, rendered_env_idx] + v0, v1, v2 = tri_vertices[:, 0], tri_vertices[:, 1], tri_vertices[:, 2] + + # Perform ray-triangle intersection + hit_result = ray_triangle_intersection(ray_start_world, ray_direction_world, v0, v1, v2) + + if hit_result.w > 0.0 and hit_result.x < closest_distance and hit_result.x >= 0.0: + closest_distance = hit_result.x + hit_face = i_f + # Compute triangle normal + edge1 = v1 - v0 + edge2 = v2 - v0 + hit_normal = edge1.cross(edge2).normalized() + else: # Internal node + # Push children onto stack + if stack_idx < ti.static(STACK_SIZE - 2): + node_stack[stack_idx] = node.left + node_stack[stack_idx + 1] = node.right + stack_idx += 2 + + # Update global closest if this environment had a closer hit + if hit_face >= 0 and closest_distance < global_closest_distance: + global_closest_distance = closest_distance + global_hit_face = hit_face + global_hit_env_idx = rendered_env_idx + global_hit_normal = hit_normal + + # Store result + if global_hit_face >= 0: + result[0] = global_closest_distance # distance (positive value indicates hit) + # Find which geom this face belongs to + i_g = faces_info.geom_idx[global_hit_face] + result[1] = gs.ti_float(i_g) + # Compute hit point + hit_point = ray_start_world + global_closest_distance * ray_direction_world + result[2] = hit_point.x + result[3] = hit_point.y + result[4] = hit_point.z + # Store normal + result[5] = global_hit_normal.x + result[6] = global_hit_normal.y + result[7] = global_hit_normal.z + result[8] = gs.ti_float(global_hit_env_idx) + + + 
+class ViewerRaycaster: + """ + BVH-accelerated raycaster for viewer interaction plugins. + + This class manages a BVH structure built from the scene's rigid geometry + and provides efficient single-ray casting for interactive applications. + Only considers environments specified in rendered_envs_idx. + """ + + def __init__(self, scene: "Scene"): + """ + Initialize the ViewerRaycaster. + + Parameters + ---------- + scene : Scene + The scene to build the raycaster for. + """ + self.scene = scene + self.solver = scene.sim.rigid_solver + + # Store rendered_envs_idx as numpy array for Taichi kernel + + # self.rendered_envs_idx = np.asarray(scene.vis_options.rendered_envs_idx or [0], dtype=gs.np_int) + self.rendered_envs_idx = np.asarray([0], dtype=gs.np_int) + + # Build the BVH structure for rendered environments. + n_faces = self.solver.faces_info.geom_idx.shape[0] + + if n_faces == 0: + gs.logger.warning("No faces found in scene, viewer raycasting will not work.") + self.aabb = None + self.bvh = None + return + + self.aabb = AABB(n_batches=len(self.rendered_envs_idx), n_aabbs=n_faces) + self.bvh = LBVH( + self.aabb, + max_n_query_result_per_aabb=0, # Not used for ray queries + n_radix_sort_groups=min(64, n_faces), + ) + + self.update_bvh() + + def update_bvh(self): + """Update the BVH structure with current geometry state.""" + if self.bvh is None: + return + + # Update vertex positions + from genesis.engine.solvers.rigid.rigid_solver_decomp import kernel_update_all_verts + + kernel_update_all_verts( + geoms_info=self.solver.geoms_info, + geoms_state=self.solver.geoms_state, + verts_info=self.solver.verts_info, + free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + ) + + # Update AABBs for each rendered environment + kernel_update_aabbs( + free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + 
aabb_state=self.aabb, + ) + + # Rebuild BVH + self.bvh.build() + + def cast_ray( + self, + ray_origin: np.ndarray, + ray_direction: np.ndarray, + max_range: float = 1000.0, + ) -> RayHit: + """ + Cast a single ray against all rendered environments and return the closest hit. + + Parameters + ---------- + ray_origin : np.ndarray, shape (3,) + The origin point of the ray in world coordinates. + ray_direction : np.ndarray, shape (3,) + The direction vector of the ray (will be normalized). + max_range : float, optional + Maximum distance to check for intersections. Default is 1000.0. + + Returns + ------- + RayHit + A RayHit object containing distance, position, normal, and geom. + If no hit, returns RayHit.no_hit(). + """ + ray_direction = ray_direction / (np.linalg.norm(ray_direction) + gs.EPS) + + ray_start_np = np.asarray(ray_origin, dtype=gs.np_float) + ray_dir_np = np.asarray(ray_direction, dtype=gs.np_float) + result_np = np.zeros(9, dtype=gs.np_float) + + kernel_cast_single_ray_for_viewer( + fixed_verts_state=self.solver.fixed_verts_state, + free_verts_state=self.solver.free_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + bvh_nodes=self.bvh.nodes, + bvh_morton_codes=self.bvh.morton_codes, + ray_start=ray_start_np, + ray_direction=ray_dir_np, + max_range=max_range, + envs_idx=self.rendered_envs_idx, + result=result_np, + ) + + distance = float(result_np[0]) + if distance < NO_HIT_DISTANCE + gs.EPS: # NO_HIT_DISTANCE + return RayHit.no_hit() + + geom_idx = int(result_np[1]) + position = Vec3(result_np[2:5]) + normal = Vec3(result_np[5:8]) + geom = self.solver.geoms[geom_idx] + + return RayHit(distance=distance, position=position, normal=normal, geom=geom) diff --git a/genesis/ext/pyrender/interaction/viewer_interaction_base.py b/genesis/ext/pyrender/interaction/viewer_interaction_base.py deleted file mode 100644 index ede759dc0c..0000000000 --- a/genesis/ext/pyrender/interaction/viewer_interaction_base.py +++ /dev/null @@ 
-1,52 +0,0 @@ -from typing import Union, Literal - -import genesis as gs - - -EVENT_HANDLE_STATE = Union[Literal[True], None] -EVENT_HANDLED: Literal[True] = True - -# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow - -class ViewerInteractionBase(): - """Base class for handling pyglet.window.Window events. - """ - - log_events: bool - - def __init__(self, log_events: bool = False): - self.log_events = log_events - - def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse moved to {x}, {y}") - - def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse dragged to {x}, {y}") - - def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} pressed at {x}, {y}") - - def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} released at {x}, {y}") - - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key pressed: {chr(symbol)}") - - def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key released: {chr(symbol)}") - - def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Window resized to {width}x{height}") - - def update_on_sim_step(self) -> None: - pass - - def on_draw(self) -> None: - pass diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index 780e10b0a8..fbf63fb22d 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -4,10 +4,10 @@ import os import shutil import sys -import time import threading +import 
time from threading import Event, RLock, Semaphore, Thread -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np import OpenGL @@ -43,12 +43,11 @@ RenderFlags, TextAlign, ) -from .interaction.viewer_interaction import ViewerInteraction -from .interaction.viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED +from .interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, VIEWER_PLUGIN_MAP from .light import DirectionalLight from .node import Node from .renderer import Renderer -from .shader_program import ShaderProgram, ShaderProgramCache +from .shader_program import ShaderProgram from .trackball import Trackball if TYPE_CHECKING: @@ -204,7 +203,7 @@ def __init__( shadow=False, plane_reflection=False, env_separate_rigid=False, - enable_interaction=False, + plugin_options=gs.options.viewer_interactions.ViewerDefaultControls(), **kwargs, ): ####################################################################### @@ -293,26 +292,6 @@ def __init__( self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] self._message_opac = 1.0 + self._ticks_till_fade - self._display_instr = False - self._instr_texts = [ - ["> [i]: show keyboard instructions"], - [ - "< [i]: hide keyboard instructions", - " [r]: record video", - " [s]: save image", - " [z]: reset camera", - " [a]: camera rotation", - " [h]: shadow", - " [f]: face normal", - " [v]: vertex normal", - " [w]: world frame", - " [l]: link frame", - " [d]: wireframe", - " [c]: camera & frustrum", - " [F11]: full-screen mode", - ], - ] - # Set up raymond lights and direct lights self._raymond_lights = self._create_raymond_lights() self._direct_light = self._create_direct_light() @@ -374,14 +353,18 @@ def __init__( self.scene.main_camera_node = self._camera_node self._reset_view() - # Setup mouse interaction - # Note: context.scene is genesis.engine.scene.Scene # Note: context._scene is genesis.ext.pyrender.scene.Scene - 
self.viewer_interaction = ( - ViewerInteraction(self._camera_node, context.scene, viewport_size, camera.yfov) - if enable_interaction - else ViewerInteractionBase() + + # Setup viewer plugin + plugin_cls = VIEWER_PLUGIN_MAP.get(type(plugin_options)) + if plugin_cls is None: + gs.raise_exception( + f"Viewer plugin type {type(plugin_options).__name__} is not registered. " + f"Available plugins: {list(VIEWER_PLUGIN_MAP.keys())}" + ) + self.interaction_plugin = plugin_cls( + self, plugin_options, self._camera_node, context.scene, viewport_size ) ####################################################################### @@ -405,7 +388,7 @@ def __init__( self._initialized_event.wait() if not self._is_active: if self._exception: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from self._exception + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from self._exception raise OpenGL.error.Error("Invalid OpenGL context.") else: if self.auto_start: @@ -585,6 +568,8 @@ def on_close(self): # Do not consider the viewer as active anymore self._is_active = False + self.interaction_plugin.on_close() + # Remove our camera and restore the prior one try: if self._camera_node is not None: @@ -733,24 +718,7 @@ def on_draw(self): self.clear() self._render() - self.viewer_interaction.on_draw() - - if self._display_instr: - self._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) + self.interaction_plugin.on_draw() if self._message_text is not None: self._renderer.render_text( @@ -791,12 +759,12 @@ def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: self._trackball.resize(self._viewport_size) self._renderer.viewport_width = self._viewport_size[0] 
self._renderer.viewport_height = self._viewport_size[1] - self.viewer_interaction.on_resize(width, height) + self.interaction_plugin.on_resize(width, height) self.on_draw() def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: """The mouse was moved with no buttons held down.""" - return self.viewer_interaction.on_mouse_motion(x, y, dx, dy) + return self.interaction_plugin.on_mouse_motion(x, y, dx, dy) def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record an initial mouse press.""" @@ -818,11 +786,11 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H # Stop animating while using the mouse self.viewer_flags["mouse_pressed"] = True - return self.viewer_interaction.on_mouse_press(x, y, button, modifiers) + return self.interaction_plugin.on_mouse_press(x, y, button, modifiers) def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: """The mouse was moved with one or more buttons held down.""" - result = self.viewer_interaction.on_mouse_drag(x, y, dx, dy, buttons, modifiers) + result = self.interaction_plugin.on_mouse_drag(x, y, dx, dy, buttons, modifiers) if result is not EVENT_HANDLED: result = self._trackball.drag(np.array([x, y])) return result @@ -830,7 +798,7 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a mouse release.""" self.viewer_flags["mouse_pressed"] = False - return self.viewer_interaction.on_mouse_release(x, y, button, modifiers) + return self.interaction_plugin.on_mouse_release(x, y, button, modifiers) def on_mouse_scroll(self, x, y, dx, dy): """Record a mouse scroll.""" @@ -868,155 +836,14 @@ def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: if len(tup) == 3: kwargs = tup[2] callback(self, *args, **kwargs) - return 
self.viewer_interaction.on_key_press(symbol, modifiers) - - # Otherwise, use default key functions - - # A causes the frame to rotate - self._message_text = None - if symbol == pyglet.window.key.A: - self.viewer_flags["rotate"] = not self.viewer_flags["rotate"] - if self.viewer_flags["rotate"]: - self._message_text = "Rotation On" - else: - self._message_text = "Rotation Off" - - # F11 toggles face normals - elif symbol == pyglet.window.key.F11: - self.viewer_flags["fullscreen"] = not self.viewer_flags["fullscreen"] - self.set_fullscreen(self.viewer_flags["fullscreen"]) - self.activate() - if self.viewer_flags["fullscreen"]: - self._message_text = "Fullscreen On" - else: - self._message_text = "Fullscreen Off" - - # H toggles shadows - elif symbol == pyglet.window.key.H: - self.render_flags["shadows"] = not self.render_flags["shadows"] - if self.render_flags["shadows"]: - self._message_text = "Shadows On" - else: - self._message_text = "Shadows Off" - - # W toggles world frame - elif symbol == pyglet.window.key.W: - if not self.gs_context.world_frame_shown: - self.gs_context.on_world_frame() - self._message_text = "World Frame On" - else: - self.gs_context.off_world_frame() - self._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == pyglet.window.key.L: - if not self.gs_context.link_frame_shown: - self.gs_context.on_link_frame() - self._message_text = "Link Frame On" - else: - self.gs_context.off_link_frame() - self._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == pyglet.window.key.C: - if not self.gs_context.camera_frustum_shown: - self.gs_context.on_camera_frustum() - self._message_text = "Camera Frustrum On" - else: - self.gs_context.off_camera_frustum() - self._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == pyglet.window.key.F: - self.render_flags["face_normals"] = not self.render_flags["face_normals"] - if self.render_flags["face_normals"]: - self._message_text = "Face 
Normals On" - else: - self._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == pyglet.window.key.V: - self.render_flags["vertex_normals"] = not self.render_flags["vertex_normals"] - if self.render_flags["vertex_normals"]: - self._message_text = "Vert Normals On" - else: - self._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == pyglet.window.key.R: - if self.viewer_flags["record"]: - self.save_video() - self.set_caption(self.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer_flags["refresh_rate"], - size=self.viewport_size, - ) - self.set_caption("{} (RECORDING)".format(self.viewer_flags["window_title"])) - self.viewer_flags["record"] = not self.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == pyglet.window.key.S: - self._save_image() - - # T toggles through geom types - # elif symbol == pyglet.window.key.T: - # if self.gs_context.rigid_shown == 'visual': - # self.gs_context.on_rigid('collision') - # self._message_text = "Geom Type: 'collision'" - # elif self.gs_context.rigid_shown == 'collision': - # self.gs_context.on_rigid('sdf') - # self._message_text = "Geom Type: 'sdf'" - # else: - # self.gs_context.on_rigid('visual') - # self._message_text = "Geom Type: 'visual'" - - # D toggles through wireframe modes - elif symbol == pyglet.window.key.D: - if self.render_flags["flip_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = True - self.render_flags["all_solid"] = False - self._message_text = "All Wireframe" - elif self.render_flags["all_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - 
self.render_flags["all_solid"] = True - self._message_text = "All Solid" - elif self.render_flags["all_solid"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Default Wireframe" - else: - self.render_flags["flip_wireframe"] = True - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == pyglet.window.key.Z: - self._reset_view() - - # i toggles instruction display - elif symbol == pyglet.window.key.I: - self._display_instr = not self._display_instr - - elif symbol == pyglet.window.key.P: - self._renderer.reload_program() - - if self._message_text is not None: - self._message_opac = 1.0 + self._ticks_till_fade - - return self.viewer_interaction.on_key_press(symbol, modifiers) + # Continue to plugins after registered callback + + # Delegate to viewer plugin + return self.interaction_plugin.on_key_press(symbol, modifiers) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" - return self.viewer_interaction.on_key_release(symbol, modifiers) + return self.interaction_plugin.on_key_release(symbol, modifiers) @staticmethod def _time_event(dt, self): @@ -1087,8 +914,7 @@ def _get_save_filename(self, file_exts): try: # Importing tkinter is very slow and not used very often. Let's delay import. - from tkinter import Tk - from tkinter import filedialog + from tkinter import Tk, filedialog if root is None: root = Tk() @@ -1230,7 +1056,8 @@ def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defi def start(self, auto_refresh=True): import pyglet # For some reason, this is necessary if 'pyglet.window.xlib' fails to import... 
try: - import pyglet.window.xlib, pyglet.display.xlib + import pyglet.display.xlib + import pyglet.window.xlib xlib_exceptions = (pyglet.window.xlib.XlibException, pyglet.display.xlib.NoSuchDisplayException) except ImportError: xlib_exceptions = () @@ -1301,7 +1128,7 @@ def start(self, auto_refresh=True): self._exception = e return else: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from e + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from e pyglet.window.xlib._have_utf8 = False confs.insert(0, conf) except (pyglet.window.NoSuchConfigException, pyglet.gl.ContextException) as e: @@ -1311,7 +1138,7 @@ def start(self, auto_refresh=True): self._exception = e return else: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from e + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from e if self._run_in_thread: pyglet.clock.schedule_interval(Viewer._time_event, 1.0 / self.viewer_flags["refresh_rate"], self) @@ -1398,7 +1225,7 @@ def refresh(self): self.flip() def update_on_sim_step(self): - self.viewer_interaction.update_on_sim_step() + self.interaction_plugin.update_on_sim_step() def _compute_initial_camera_pose(self): centroid = self.scene.centroid diff --git a/genesis/options/sensors/raycaster.py b/genesis/options/sensors/raycaster.py index b21d41eaf1..0a42e07d99 100644 --- a/genesis/options/sensors/raycaster.py +++ b/genesis/options/sensors/raycaster.py @@ -55,6 +55,21 @@ def ray_starts(self) -> torch.Tensor: # ============================== Generic Patterns ============================== +def _sanitize_rays_to_tensor(rays: Sequence[float]) -> torch.Tensor: + tensor = torch.tensor(rays, dtype=gs.tc_float, device=gs.device) + if tensor.ndim < 2 or tensor.shape[-1] != 3: + gs.raise_exception(f"Rays should have shape (..., 3). 
Got: {tensor.shape}") + return tensor + + +class RaycastCustomPattern(RaycastPattern): + + def __init__(self, ray_dirs: Sequence[float], ray_starts: Sequence[float]): + self._ray_dirs = _sanitize_rays_to_tensor(ray_dirs) + self._ray_starts = _sanitize_rays_to_tensor(ray_starts) + self._return_shape: tuple[int, ...] = tuple(self._ray_dirs.shape[:-1]) + + class GridPattern(RaycastPattern): """ Configuration for grid-based ray casting. diff --git a/genesis/options/viewer_interactions.py b/genesis/options/viewer_interactions.py new file mode 100644 index 0000000000..7e72a0e6b2 --- /dev/null +++ b/genesis/options/viewer_interactions.py @@ -0,0 +1,56 @@ +from .options import Options + + +class ViewerInteraction(Options): + """ + Base class for viewer interaction options. + + All viewer interaction option classes should inherit from this base class. + """ + + pass + + +class ViewerDefaultControls(ViewerInteraction): + """ + Default viewer interaction controls with keyboard shortcuts for recording, changing render modes, etc. + + Parameters + ---------- + keybindings : dict[str, int] | None + Override the default mapping of action names to keyboard key codes (pyglet.window.key.*). + """ + + keybindings: dict[str, int] | None = None + + +class MouseSpringViewerPlugin(ViewerDefaultControls): + """ + Options for the interactive viewer plugin that allows mouse-based object manipulation. + """ + + pass + + +class MeshPointSelectorPlugin(ViewerDefaultControls): + """ + Options for the mesh point selector plugin that allows selecting points on a mesh. + + Parameters + ---------- + sphere_radius : float + The radius of the sphere used to visualize selected points. + sphere_color : tuple + The color of the sphere used to visualize selected points. + hover_color : tuple + The color of the sphere used to visualize the point and normal when hovering over a mesh. + grid_snap : tuple[float, float, float] + Grid snap spacing for each axis (x, y, z). Any negative value disables snapping for that axis.
+ Default is (-1.0, -1.0, -1.0) which means no snapping. + """ + + sphere_radius: float = 0.005 + sphere_color: tuple = (0.1, 0.3, 1.0, 1.0) + hover_color: tuple = (0.3, 0.5, 1.0, 1.0) + grid_snap: tuple[float, float, float] = (-1.0, -1.0, -1.0) + output_file: str = "selected_points.csv" diff --git a/genesis/options/vis.py b/genesis/options/vis.py index 1d955c8b83..ae2ffded6b 100644 --- a/genesis/options/vis.py +++ b/genesis/options/vis.py @@ -3,6 +3,7 @@ import genesis as gs from .options import Options +from .viewer_interactions import ViewerDefaultControls, ViewerInteraction class ViewerOptions(Options): @@ -33,17 +34,19 @@ class ViewerOptions(Options): The up vector of the camera's extrinsic pose. camera_fov : float The field of view (in degrees) of the camera. + viewer_plugin : ViewerInteraction + Viewer plugin that adds interactive functionality to the viewer. """ - res: Optional[tuple] = None - run_in_thread: Optional[bool] = None + res: tuple | None = None + run_in_thread: bool | None = None refresh_rate: int = 60 - max_FPS: Optional[int] = 60 + max_FPS: int | None = 60 camera_pos: tuple = (3.5, 0.5, 2.5) camera_lookat: tuple = (0.0, 0.0, 0.5) camera_up: tuple = (0.0, 0.0, 1.0) camera_fov: float = 40 - enable_interaction: bool = False + viewer_plugin: ViewerInteraction = ViewerDefaultControls() class VisOptions(Options): @@ -139,7 +142,7 @@ def __init__(self, **data): f"Unsupported `render_particle_as`: {self.render_particle_as}, must be one of ['sphere', 'tet']" ) - if not self.n_rendered_envs is None: + if self.n_rendered_envs is not None: gs.logger.warning( "Viewer option 'n_rendered_envs' is deprecated and will be removed in future release. Please use " "'rendered_envs_idx' instead."
diff --git a/genesis/utils/raycast.py b/genesis/utils/raycast.py new file mode 100644 index 0000000000..4779622309 --- /dev/null +++ b/genesis/utils/raycast.py @@ -0,0 +1,121 @@ +import gstaichi as ti + +import genesis as gs + + +@ti.func +def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): + """ + Moller-Trumbore ray-triangle intersection. + + Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise + """ + result = ti.Vector.zero(gs.ti_float, 4) + + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Begin calculating determinant - also used to calculate u parameter + h = ray_dir.cross(edge2) + a = edge1.dot(h) + + # Check all conditions in sequence without early returns + valid = True + + t = gs.ti_float(0.0) + u = gs.ti_float(0.0) + v = gs.ti_float(0.0) + f = gs.ti_float(0.0) + s = ti.Vector.zero(gs.ti_float, 3) + q = ti.Vector.zero(gs.ti_float, 3) + + # If determinant is near zero, ray lies in plane of triangle + if ti.abs(a) < gs.EPS: + valid = False + + if valid: + f = 1.0 / a + s = ray_start - v0 + u = f * s.dot(h) + + if u < 0.0 or u > 1.0: + valid = False + + if valid: + q = s.cross(edge1) + v = f * ray_dir.dot(q) + + if v < 0.0 or u + v > 1.0: + valid = False + + if valid: + # At this stage we can compute t to find out where the intersection point is on the line + t = f * edge2.dot(q) + + # Ray intersection + if t <= gs.EPS: + valid = False + + if valid: + result = ti.math.vec4(t, u, v, 1.0) + + return result + + +@ti.func +def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): + """ + Fast ray-AABB intersection test. + Returns the t value of intersection, or -1.0 if no intersection. 
+ """ + result = -1.0 + + # Use the slab method for ray-AABB intersection + sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) + ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) + inv_dir = 1.0 / ray_dir + + t1 = (aabb_min - ray_start) * inv_dir + t2 = (aabb_max - ray_start) * inv_dir + + tmin = ti.min(t1, t2) + tmax = ti.max(t1, t2) + + t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) + t_far = ti.min(tmax.x, tmax.y, tmax.z) + + # Check if ray intersects AABB + if t_near <= t_far: + result = t_near + + return result + + +@ti.kernel +def kernel_update_aabbs( + free_verts_state: ti.template(), + fixed_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + # FIXME: can't import array_class since it is before gs.init + # free_verts_state: array_class.VertsState, + # fixed_verts_state: array_class.VertsState, + # verts_info: array_class.VertsInfo, + # faces_info: array_class.FacesInfo, + aabb_state: ti.template(), +): + for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): + aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) + aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) + + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + pos_v = fixed_verts_state.pos[i_fv] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) + else: + pos_v = free_verts_state.pos[i_fv, i_b] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index 77c4bce345..d924860d09 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -1,6 +1,6 @@ +import importlib import os import threading -import importlib from typing import TYPE_CHECKING import numpy as np @@ -9,11 +9,10 @@ 
import genesis as gs import genesis.utils.geom as gu - from genesis.ext import pyrender from genesis.repr_base import RBC -from genesis.utils.tools import Rate from genesis.utils.misc import redirect_libc_stderr, tensor_to_array +from genesis.utils.tools import Rate if TYPE_CHECKING: from genesis.options.vis import ViewerOptions @@ -40,15 +39,12 @@ def __init__(self, options: "ViewerOptions", context): self._camera_init_lookat = np.asarray(options.camera_lookat, dtype=gs.np_float) self._camera_up = np.asarray(options.camera_up, dtype=gs.np_float) self._camera_fov = options.camera_fov - self._enable_interaction = options.enable_interaction + self._viewer_plugin = options.viewer_plugin # Validate viewer options if any(e.shape != (3,) for e in (self._camera_init_pos, self._camera_init_lookat, self._camera_up)): gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.") - if options.enable_interaction and gs.backend != gs.cpu: - gs.logger.warning("Interaction code is slow on GPU. Switch to CPU backend or disable interaction.") - self._pyrender_viewer = None self.context = context @@ -99,7 +95,7 @@ def build(self, scene): shadow=self.context.shadow, plane_reflection=self.context.plane_reflection, env_separate_rigid=self.context.env_separate_rigid, - enable_interaction=self._enable_interaction, + plugin_options=self._viewer_plugin, viewer_flags={ "window_title": f"Genesis {gs.__version__}", "refresh_rate": self._refresh_rate,