Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 5 additions & 116 deletions genesis/engine/sensors/raycaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

import genesis as gs
import genesis.utils.array_class as array_class
from genesis.engine.bvh import AABB, LBVH, STACK_SIZE
from genesis.options.sensors import (
Raycaster as RaycasterOptions,
)
from genesis.options.sensors import (
RaycastPattern,
)
from genesis.engine.bvh import AABB, LBVH, STACK_SIZE
from genesis.utils.geom import (
ti_normalize,
ti_transform_by_quat,
Expand All @@ -21,6 +23,7 @@
transform_by_trans_quat,
)
from genesis.utils.misc import concat_with_tensor, make_tensor_field
from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection
from genesis.vis.rasterizer_context import RasterizerContext

from .base_sensor import (
Expand All @@ -36,119 +39,6 @@
from genesis.utils.ring_buffer import TensorRingBuffer


@ti.func
def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2):
"""
Moller-Trumbore ray-triangle intersection.

Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise
"""
result = ti.Vector.zero(gs.ti_float, 4)

edge1 = v1 - v0
edge2 = v2 - v0

# Begin calculating determinant - also used to calculate u parameter
h = ray_dir.cross(edge2)
a = edge1.dot(h)

# Check all conditions in sequence without early returns
valid = True

t = gs.ti_float(0.0)
u = gs.ti_float(0.0)
v = gs.ti_float(0.0)
f = gs.ti_float(0.0)
s = ti.Vector.zero(gs.ti_float, 3)
q = ti.Vector.zero(gs.ti_float, 3)

# If determinant is near zero, ray lies in plane of triangle
if ti.abs(a) < gs.EPS:
valid = False

if valid:
f = 1.0 / a
s = ray_start - v0
u = f * s.dot(h)

if u < 0.0 or u > 1.0:
valid = False

if valid:
q = s.cross(edge1)
v = f * ray_dir.dot(q)

if v < 0.0 or u + v > 1.0:
valid = False

if valid:
# At this stage we can compute t to find out where the intersection point is on the line
t = f * edge2.dot(q)

# Ray intersection
if t <= gs.EPS:
valid = False

if valid:
result = ti.math.vec4(t, u, v, 1.0)

return result


@ti.func
def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max):
"""
Fast ray-AABB intersection test.
Returns the t value of intersection, or -1.0 if no intersection.
"""
result = -1.0

# Use the slab method for ray-AABB intersection
sign = ti.select(ray_dir >= 0.0, 1.0, -1.0)
ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS)
inv_dir = 1.0 / ray_dir

t1 = (aabb_min - ray_start) * inv_dir
t2 = (aabb_max - ray_start) * inv_dir

tmin = ti.min(t1, t2)
tmax = ti.max(t1, t2)

t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0)
t_far = ti.min(tmax.x, tmax.y, tmax.z)

# Check if ray intersects AABB
if t_near <= t_far:
result = t_near

return result


@ti.kernel
def kernel_update_aabbs(
free_verts_state: array_class.VertsState,
fixed_verts_state: array_class.VertsState,
verts_info: array_class.VertsInfo,
faces_info: array_class.FacesInfo,
aabb_state: ti.template(),
):
for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]):
aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf)
aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf)

for i in ti.static(range(3)):
i_v = faces_info.verts_idx[i_f][i]
i_fv = verts_info.verts_state_idx[i_v]
if verts_info.is_fixed[i_v]:
pos_v = fixed_verts_state.pos[i_fv]
aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v)
aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v)
else:
pos_v = free_verts_state.pos[i_fv, i_b]
aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v)
aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v)


@ti.kernel
def kernel_cast_rays(
fixed_verts_state: array_class.VertsState,
Expand Down Expand Up @@ -195,8 +85,7 @@ def kernel_cast_rays(
ray_dir_local = ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], ray_directions[i_p, 2])
ray_direction_world = ti_normalize(ti_transform_by_quat(ray_dir_local, link_quat), gs.EPS)

# --- 2. BVH Traversal ---
# FIXME: this duplicates the logic in LBVH.query() which also does traversal
# --- 2. BVH Traversal for ray intersection ---

max_range = max_ranges[i_s]
hit_face = -1
Expand Down
12 changes: 12 additions & 0 deletions genesis/ext/pyrender/interaction/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .base_interaction import (
EVENT_HANDLE_STATE,
EVENT_HANDLED,
VIEWER_PLUGIN_MAP,
BaseViewerInteraction,
register_viewer_plugin,
)
from .plugins.mesh_selector import MeshPointSelectorPlugin
from .plugins.mouse_interaction import MouseSpringViewerPlugin
from .plugins.viewer_controls import ViewerDefaultControls
from .ray import Plane, Ray, RayHit
from .vec3 import Color, Pose, Quat, Vec3
127 changes: 127 additions & 0 deletions genesis/ext/pyrender/interaction/base_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import TYPE_CHECKING, Literal, Type

import numpy as np

from .ray import Ray
from .vec3 import Vec3

if TYPE_CHECKING:
from genesis.engine.scene import Scene
from genesis.ext.pyrender.node import Node
from genesis.options.viewer_interactions import ViewerInteraction as ViewerPluginOptions


EVENT_HANDLE_STATE = Literal[True] | None
EVENT_HANDLED: Literal[True] = True

# Global map from options class to viewer plugin class
VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["BaseViewerInteraction"]] = {}


def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]):
"""
Decorator to register a viewer plugin class with its corresponding options class.

Parameters
----------
options_cls : Type[ViewerPluginOptions]
The options class that configures this viewer plugin.

Returns
-------
Callable
The decorator function that registers the plugin class.

Example
-------
@register_viewer_plugin(ViewerInteractionOptions)
class ViewerInteraction(ViewerInteractionBase):
...
"""
def _impl(plugin_cls: Type["BaseViewerInteraction"]):
VIEWER_PLUGIN_MAP[options_cls] = plugin_cls
return plugin_cls
return _impl

# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow

class BaseViewerInteraction():
"""
Base class for handling pyglet.window.Window events.
"""

def __init__(
self,
viewer,
options: "ViewerPluginOptions",
camera: "Node",
scene: "Scene",
viewport_size: tuple[int, int],
):
self.viewer = viewer
self.options: "ViewerPluginOptions" = options
self.camera: 'Node' = camera
self.scene: 'Scene' = scene
self.viewport_size: tuple[int, int] = viewport_size

self.camera_yfov: float = camera.camera.yfov
self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov)

def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE:
pass

def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE:
pass

def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
pass

def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE:
pass

def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE:
pass

def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE:
pass

def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE:
self.viewport_size = (width, height)
self.tan_half_fov = np.tan(0.5 * self.camera_yfov)

def update_on_sim_step(self) -> None:
pass

def on_draw(self) -> None:
pass

def on_close(self) -> None:
pass

def _screen_position_to_ray(self, x: float, y: float) -> Ray:
# convert screen position to ray
x = x - 0.5 * self.viewport_size[0]
y = y - 0.5 * self.viewport_size[1]
x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov
y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov

# Note: ignoring pixel aspect ratio

mtx = self.camera.matrix
position = Vec3.from_array(mtx[:3, 3])
forward = Vec3.from_array(-mtx[:3, 2])
right = Vec3.from_array(mtx[:3, 0])
up = Vec3.from_array(mtx[:3, 1])

direction = forward + right * x + up * y
return Ray(position, direction)

def _get_camera_forward(self) -> Vec3:
mtx = self.camera.matrix
return Vec3.from_array(-mtx[:3, 2])

def _get_camera_ray(self) -> Ray:
mtx = self.camera.matrix
position = Vec3.from_array(mtx[:3, 3])
forward = Vec3.from_array(-mtx[:3, 2])
return Ray(position, forward)
27 changes: 27 additions & 0 deletions genesis/ext/pyrender/interaction/keybindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pyglet.window.key import symbol_string


class Keybindings:
def __init__(self, map: dict[str, int] = {}, **kwargs: dict[str, int]):
self._map: dict[str, int] = {**map, **kwargs}

def __getattr__(self, name: str) -> int:
if name in self._map:
return self._map[name]
raise AttributeError(f"Action '{name}' not found in keybindings.")

def as_instruction_texts(self, padding, exclude: tuple[str]) -> list[str]:
width = 4 + padding
return [
f"{'[' + symbol_string(self._map[action]).lower():>{width}}]: " +
action.replace('_', ' ') for action in self._map.keys() if action not in exclude
]

def extend(self, mapping: dict[str, int], replace_only: bool = False) -> None:
current_keys = self._map.keys()
for action, key in mapping.items():
if replace_only and action not in self._map:
raise KeyError(f"Action '{action}' not found. Available actions: {list(self._map.keys())}")
if key in current_keys:
raise ValueError(f"Key '{symbol_string(key)}' is already assigned to another action.")
self._map[action] = key
Loading