stream_hacks to main #2

Merged
skeh merged 8 commits from stream_hacks into main 2024-12-11 03:22:13 +00:00
3 changed files with 27 additions and 4 deletions
Showing only changes of commit c9db70f57d - Show all commits

View File

@ -74,6 +74,17 @@ class Process(TransformProcess):
joints[JOINT_TYPES.HEAD] = head_joint joints[JOINT_TYPES.HEAD] = head_joint
# Synthizise other joints from existing data
if not joints.get(JOINT_TYPES.CHEST) and joints.get(JOINT_TYPES.HEAD):
chest_center = joints[JOINT_TYPES.HEAD].pos.as_np()
chest_center = np.power(chest_center, 3) / (1e3 + np.power(chest_center, 2))
chest_center -= [0, 100, 0]
chest_rot = Quaternion.identity().slerp(joints[JOINT_TYPES.HEAD].rot, 0.1)
chest_joint = Joint(Point3d(*chest_center), chest_rot)
joints[JOINT_TYPES.CHEST] = chest_joint
skeleton = Skeleton(joints) skeleton = Skeleton(joints)
self._outputs['skel'].send(skeleton) self._outputs['skel'].send(skeleton)

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation, Slerp
from .Type import Type from .Type import Type
@ -12,6 +12,10 @@ class Quaternion(Type):
y: float y: float
z: float z: float
@classmethod
def identity(cls):
return cls(1, 0, 0, 0)
def __mul__(self, q): def __mul__(self, q):
if isinstance(q, self.__class__): if isinstance(q, self.__class__):
product = self.as_np() * q.as_np() product = self.as_np() * q.as_np()
@ -45,6 +49,16 @@ class Quaternion(Type):
def conjugate(self): def conjugate(self):
return self.__class__(self.w, -self.x, -self.y, -self.z) return self.__class__(self.w, -self.x, -self.y, -self.z)
def slerp(self, other, t):
r = Rotation.from_quat([
[self.x, self.y, self.z, self.w],
[other.x, other.y, other.z, other.w],
])
slerp = Slerp([0, 1], r)
x, y, z, w = slerp([t]).as_quat()[0]
return self.__class__(w, x, y, z)
def draw(self, canvas, origin): def draw(self, canvas, origin):
raise NotImplementedError() raise NotImplementedError()

View File

@ -2,9 +2,6 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
import typing import typing
import cv2
import matplotlib.pyplot as plt
from .Type import Type from .Type import Type
from .Point3d import Point3d from .Point3d import Point3d
from .Quaternion import Quaternion from .Quaternion import Quaternion
@ -34,6 +31,7 @@ default_colors = [
[128, 255, 0], [255, 0, 128], [0, 255, 128], [128, 255, 0], [255, 0, 128], [0, 255, 128],
] ]
@dataclass @dataclass
class Joint(Type): class Joint(Type):
pos: Point3d pos: Point3d