diff --git a/nuke_live_portrait.py b/nuke_live_portrait.py
index 51de19b..db5d471 100644
--- a/nuke_live_portrait.py
+++ b/nuke_live_portrait.py
@@ -1,12 +1,27 @@
-import torch
 import logging
+import torch
+import os
+import torch.nn.functional as F
+from torch import nn
 from src.utils.helper import load_model
+from src.utils.camera import get_rotation_matrix
+from typing import List, Optional, Tuple, Dict
+import numpy as np
+import torchvision
 
-logging.basicConfig(level=logging.INFO)
+
+try:
+    from rich.logging import RichHandler
+except ImportError:
+    RichHandler = None
+
+# Disable CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR warnings
+torch.backends.cudnn.benchmark = True
+
+logging.basicConfig(level=logging.INFO, handlers=[RichHandler()] if RichHandler else None)
 LOGGER = logging.getLogger(__name__)
-
-model_config = {
+MODEL_CONFIG = {
     "model_params": {
         "appearance_feature_extractor_params": {
             "image_channel": 3,
@@ -60,10 +75,857 @@ model_config = {
 }
 
 
+torch.set_printoptions(precision=4, sci_mode=False)
+np.set_printoptions(precision=4, suppress=True)
+
+
+def file_size(file_path):
+    """Get the file size in MB."""
+    size_in_bytes = os.path.getsize(file_path)
+    return int(size_in_bytes / (1024 * 1024))
+
+
+# --- Nuke models ---
+
+
+class LivePortraitNukeFaceDetection(nn.Module):
+    """Live Portrait face detection model for Nuke.
+
+    Detect facial landmarks, then crop, align, and stabilize
+    the face for further processing.
+
+    Args:
+        face_detection: The face detection model.
+        face_alignment: The face alignment model.
+        scale: The crop scale factor.
+    """
+
+    def __init__(self, face_detection, face_alignment, scale=2.3) -> None:
+        """Initialize the model."""
+        super().__init__()
+        self.face_detection = face_detection
+        self.face_alignment = face_alignment
+        self.filter_threshold = 0.5
+        self.reference_scale = 195
+        self.resolution = 256
+        self.resize = torchvision.transforms.Resize(
+            (256, 256),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+            max_size=None,
+            antialias=True,
+        )
+
+        self.crop_cfg_dsize = 512
+        self.crop_cfg_scale = scale
+        self.crop_cfg_vx_ratio = 0.0
+        self.crop_cfg_vy_ratio = -0.125
+        self.crop_cfg_face_index = 0
+        self.crop_cfg_face_index_order = "large-small"
+        self.crop_cfg_rotate = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        MAIN_FACE_INDEX = 0
+        resolution = self.resolution
+        b, c, h, w = x.shape
+        device = torch.device("cuda") if x.is_cuda else torch.device("cpu")
+        x0 = x.clone()
+
+        x = x * 255
+        x = x.flip(-3)  # RGB to BGR
+        x = x - torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1)
+
+        olist = self.face_detection(x)  # olist = net(img_batch)  # patched uint8_t overflow error
+
+        for i in range(len(olist) // 2):
+            olist[i * 2] = F.softmax(olist[i * 2], dim=1)
+
+        bboxlists = self.get_predictions(olist)
+        detected_faces = self.filter_bboxes(bboxlists[0])
+
+        if detected_faces.size(0) == 0:
+            print("LivePortrait Face Detection: No faces detected")
+            return torch.zeros((b, 4, h, w), device=device)
+
+        d = detected_faces[MAIN_FACE_INDEX]  # [x_min, y_min, x_max, y_max, score]
+
+        d_x_min = float(d[0])
+        d_y_min = float(d[1])
+        d_x_max = float(d[2])
+        d_y_max = float(d[3])
+
+        center_x = d_x_max - (d_x_max - d_x_min) / 2.0
+        center_y = d_y_max - (d_y_max - d_y_min) / 2.0
+        center_y = center_y - (d_y_max - d_y_min) * 0.12
+
+        scale = (d_x_max - d_x_min + d_y_max - d_y_min) / self.reference_scale
+
+        ul = self.transform(
+            torch.tensor([[1, 1]]),
+            torch.tensor([[center_x, center_y]], device=device),
+            scale,
+            resolution,
+            True,
+        )[0]
+        ul_x = int(ul[0])
+        ul_y = int(ul[1])
+
+        br = self.transform(
+            torch.tensor([[resolution, resolution]]),
+            torch.tensor([[center_x, center_y]], device=device),
+            scale,
+            resolution,
+            True,
+        )[0]
+        br_x = int(br[0])
+        br_y = int(br[1])
+
+        crop = torch.zeros([b, c, br_y - ul_y, br_x - ul_x], device=device)
+        newX = torch.tensor([max(1, -ul_x + 1), min(br_x, w) - ul_x])
+        newY = torch.tensor([max(1, -ul_y + 1), min(br_y, h) - ul_y])
+
+        oldX = torch.tensor([max(1, ul_x + 1), min(br_x, w)])
+        oldY = torch.tensor([max(1, ul_y + 1), min(br_y, h)])
+
+        crop[:, :, newY[0] - 1 : newY[1], newX[0] - 1 : newX[1]] = x0[
+            :, :, oldY[0] - 1 : oldY[1], oldX[0] - 1 : oldX[1]
+        ]
+
+        crop_resized = self.resize(crop)
+        fa_out = self.face_alignment(crop_resized)
+
+        pts, pts_img, scores = self.get_preds_fromhm(
+            fa_out, torch.tensor([center_x, center_y], device=device), scale
+        )
+
+        lmk_tensor = pts_img.squeeze()
+
+        ret_dct = self.crop_image(
+            x0.squeeze(),
+            lmk_tensor,
+            dsize=self.crop_cfg_dsize,
+            scale=self.crop_cfg_scale,
+            vx_ratio=self.crop_cfg_vx_ratio,
+            vy_ratio=self.crop_cfg_vy_ratio,
+            flag_do_rot=True,
+            use_lip=True,
+        )
+
+        ret_dct["cropped_image_256"] = self.resize(ret_dct["img_crop"].permute(2, 0, 1).unsqueeze(0))
+        ret_dct["pt_crop_256x256"] = ret_dct["pt_crop"] * 256 / self.crop_cfg_dsize
+
+        # TODO: implement self.landmark_runner.run(img_rgb, pts)
+
+        out = torch.zeros((b, 4, h, w), device=device)  # RGBA
+        out[0, :3, 0:256, 0:256] = ret_dct["cropped_image_256"].squeeze()
+        out[0, -1, 0, :5] = detected_faces[MAIN_FACE_INDEX].reshape(-1)
+        out[0, -1, 1:3, :68] = ret_dct["pt_crop"].permute(1, 0)
+        out[0, -1, 3:4, :9] = ret_dct["M_o2c"].reshape(-1)
+        out[0, -1, 4:5, :9] = ret_dct["M_c2o"].reshape(-1)
+
+        return out.contiguous()
+
+    # All functions below were adapted from the original code,
+    # stripped of any external dependencies.
+
+    def decode(self, loc, priors, variances):
+        """Decode locations from predictions using priors to undo
+        the encoding we did for offset regression at train time.
+
+        Args:
+            loc (tensor): location predictions for loc layers,
+                Shape: [num_priors,4]
+            priors (tensor): Prior boxes in center-offset form.
+                Shape: [num_priors,4].
+            variances: (list[float]) Variances of prior boxes
+        Return:
+            decoded bounding box predictions
+
+        """
+        boxes = torch.cat(
+            (
+                priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+                priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
+            ),
+            1,
+        )
+        boxes[:, :2] -= boxes[:, 2:] / 2
+        boxes[:, 2:] += boxes[:, :2]
+        return boxes
+
+    def get_predictions(self, olist: List[torch.Tensor]):
+        variances = torch.tensor([0.1, 0.2], dtype=torch.float32, device=olist[0].device)
+        bboxlists = []
+
+        for i in range(len(olist) // 2):
+            ocls, oreg = olist[i * 2], olist[i * 2 + 1]
+            stride = 2 ** (i + 2)
+            mask = ocls[:, 1, :, :] > 0.05
+            hindex, windex = torch.where(mask)[1], torch.where(mask)[2]
+
+            for idx in range(hindex.size(0)):
+                h = hindex[idx]
+                w = windex[idx]
+                axc, ayc = stride / 2 + w * stride, stride / 2 + h * stride
+                axc_float = float(axc)
+                ayc_float = float(ayc)
+                priors = torch.tensor(
+                    [[axc_float / 1.0, ayc_float / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]],
+                    device=ocls.device,
+                )
+                score = ocls[:, 1, h, w].unsqueeze(1)
+                loc = oreg[:, :, h, w].clone()
+                boxes = self.decode(loc, priors, variances)
+                bboxlists.append(torch.cat((boxes, score), dim=1))
+
+        if len(bboxlists) == 0:
+            output = torch.zeros((1, 0, 5))  # 5 columns: x_min, y_min, x_max, y_max, score
+        else:
+            output = torch.stack(bboxlists, dim=1)
+
+        return output
+
+    def nms(self, dets: torch.Tensor, thresh: float) -> List[int]:
+        if dets.size(0) == 0:
+            return []
+
+        x1 = dets[:, 0]
+        y1 = dets[:, 1]
+        x2 = dets[:, 2]
+        y2 = dets[:, 3]
+        scores = dets[:, 4]
+
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = torch.argsort(scores, descending=True)
+
+        keep: List[int] = []
+
+        while order.numel() > 0:
+            i = int(order[0])
+            keep.append(i)
+
+            if order.size(0) == 1:
+                break
+
+            xx1 = torch.max(x1[i], x1[order[1:]])
+            yy1 = torch.max(y1[i], y1[order[1:]])
+            xx2 = torch.min(x2[i], x2[order[1:]])
+            yy2 = torch.min(y2[i], y2[order[1:]])
+
+            w = torch.clamp(xx2 - xx1 + 1, min=0.0)
+            h = torch.clamp(yy2 - yy1 + 1, min=0.0)
+
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+            inds = torch.where(ovr <= thresh)[0]
+            order = order[inds + 1]
+
+        return keep
+
+    def filter_bboxes(self, bboxlist):
+        nms_thresh = 0.3
+
+        if bboxlist.size(0) > 0:
+            keep_indices = self.nms(bboxlist, nms_thresh)
+            bboxlist = bboxlist[keep_indices]
+
+            mask = bboxlist[:, -1] > self.filter_threshold
+            bboxlist = bboxlist[mask]
+
+        return bboxlist
+
+    def transform(self, points, center, scale: float, resolution: int, invert: bool = False):
+        """Generate an affine transformation matrix.
+
+        Given a set of points, a center, a scale and a target resolution, the
+        function generates an affine transformation matrix. If invert is ``True``
+        it will produce the inverse transformation.
+
+        Arguments:
+            points -- the input 2D points
+            center -- the center around which to perform the transformations
+            scale -- the scale of the face/object
+            resolution -- the output resolution
+
+        Keyword Arguments:
+            invert {bool} -- defines whether the function should produce the direct or the
+                inverse transformation matrix (default: {False})
+        """
+        N = points.shape[0]
+        _pt = torch.ones(N, 3, device=points.device)
+        _pt[:, 0:2] = points
+
+        h = 200.0 * scale
+
+        t = torch.eye(3, device=points.device).unsqueeze(0).repeat(N, 1, 1)  # [N, 3, 3]
+        t[:, 0, 0] = resolution / h
+        t[:, 1, 1] = resolution / h
+        t[:, 0, 2] = resolution * (-center[:, 0] / h + 0.5)
+        t[:, 1, 2] = resolution * (-center[:, 1] / h + 0.5)
+
+        if invert:
+            t = torch.inverse(t)
+
+        new_point = torch.bmm(t, _pt.unsqueeze(-1))[:, 0:2, 0]  # [N, 2]
+
+        return new_point.long()
+
+    def _get_preds_fromhm(
+        self,
+        hm: torch.Tensor,
+        idx: torch.Tensor,
+        center: Optional[torch.Tensor],
+        scale: Optional[float] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Obtain (x,y) coordinates given a set of N heatmaps and the
+        corresponding locations of the maximums. If the center
+        and the scale are provided the function will return the points also in
+        the original coordinate frame.
+
+        Arguments:
+            hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+        Keyword Arguments:
+            center {torch.tensor} -- the center of the bounding box (default: {None})
+            scale {float} -- face scale (default: {None})
+        """
+        B, C, H, W = hm.shape
+        idx = idx + 1
+
+        preds = idx.unsqueeze(-1).repeat(1, 1, 2).float()
+        preds[:, :, 0] = (preds[:, :, 0] - 1) % W + 1
+        preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / H) + 1
+
+        hm_padded = F.pad(hm, (1, 1, 1, 1), mode="replicate")  # [B, C, H+2, W+2]
+
+        pX = preds[:, :, 0].long() - 1
+        pY = preds[:, :, 1].long() - 1
+
+        pX_padded = pX + 1
+        pY_padded = pY + 1
+
+        pX_p1 = pX_padded + 1
+        pX_m1 = pX_padded - 1
+        pY_p1 = pY_padded + 1
+        pY_m1 = pY_padded - 1
+
+        B_C = B * C
+        hm_padded_flat = hm_padded.contiguous().view(B_C, H + 2, W + 2)
+        pX_padded_flat = pX_padded.view(B_C)
+        pY_padded_flat = pY_padded.view(B_C)
+        pX_p1_flat = pX_p1.view(B_C)
+        pX_m1_flat = pX_m1.view(B_C)
+        pY_p1_flat = pY_p1.view(B_C)
+        pY_m1_flat = pY_m1.view(B_C)
+
+        batch_channel_indices = torch.arange(B_C, device=hm.device)
+
+        hm_padded_flat = hm_padded_flat.float()
+        val_pX_p1 = hm_padded_flat[batch_channel_indices, pY_padded_flat, pX_p1_flat]
+        val_pX_m1 = hm_padded_flat[batch_channel_indices, pY_padded_flat, pX_m1_flat]
+        val_pY_p1 = hm_padded_flat[batch_channel_indices, pY_p1_flat, pX_padded_flat]
+        val_pY_m1 = hm_padded_flat[batch_channel_indices, pY_m1_flat, pX_padded_flat]
+
+        diff_x = val_pX_p1 - val_pX_m1
+        diff_y = val_pY_p1 - val_pY_m1
+        diff = torch.stack([diff_x, diff_y], dim=1)  # Shape [B_C, 2]
+
+        sign_diff = torch.sign(diff)
+        preds_flat = preds.view(B_C, 2)
+        preds_flat += sign_diff * 0.25
+
+        preds = preds_flat.view(B, C, 2)
+        preds -= 0.5
+
+        if center is not None and scale is not None:
+            preds_flat = preds.view(B * C, 2)
+            if center.dim() == 1:
+                center_expanded = center.view(1, 2).expand(B * C, 2)
+            else:
+                center_expanded = center.view(B, 1, 2).expand(B, C, 2).reshape(B * C, 2)
+            preds_orig_flat = self.transform(preds_flat, center_expanded, scale, H, invert=True)
+            preds_orig = preds_orig_flat.view(B, C, 2)
+        else:
+            preds_orig = torch.zeros_like(preds)
+
+        return preds, preds_orig
+
+    def get_preds_fromhm(
+        self, hm: torch.Tensor, center: torch.Tensor, scale: Optional[float] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Obtain (x,y) coordinates given a set of N heatmaps. If the center
+        and the scale are provided the function will return the points also in
+        the original coordinate frame.
+
+        Arguments:
+            hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+        Keyword Arguments:
+            center {torch.tensor} -- the center of the bounding box (default: {None})
+            scale {float} -- face scale (default: {None})
+        """
+        B, C, H, W = hm.shape
+        hm_reshape = hm.view(B, C, H * W)
+        idx = torch.argmax(hm_reshape, dim=-1)
+        scores = torch.gather(hm_reshape, dim=-1, index=idx.unsqueeze(-1)).squeeze(-1)
+        preds, preds_orig = self._get_preds_fromhm(hm, idx, center, scale)
+        return preds, preds_orig, scores
+
+    def crop_image(
+        self,
+        img: torch.Tensor,
+        pts: torch.Tensor,
+        dsize: int = 224,
+        scale: float = 1.5,
+        vx_ratio: float = 0.0,
+        vy_ratio: float = -0.1,
+        flag_do_rot: bool = True,
+        use_lip: bool = True,
+    ) -> Dict[str, torch.Tensor]:
+        pts = pts.float()
+        M_INV, _ = self._estimate_similar_transform_from_pts(
+            pts,
+            dsize=dsize,
+            scale=scale,
+            vx_ratio=vx_ratio,
+            vy_ratio=vy_ratio,
+            flag_do_rot=flag_do_rot,
+            use_lip=use_lip,
+        )
+
+        img_crop = self._transform_img(img, M_INV, dsize)
+        pt_crop = self._transform_pts(pts, M_INV)
+
+        M_o2c = torch.vstack(
+            [M_INV, torch.tensor([0, 0, 1], dtype=M_INV.dtype, device=M_INV.device)]
+        )
+        M_c2o = torch.inverse(M_o2c)
+
+        ret_dct = {
+            "M_o2c": M_o2c,
+            "M_c2o": M_c2o,
+            "img_crop": img_crop,
+            "pt_crop": pt_crop,
+        }
+
+        return ret_dct
+
+    def _estimate_similar_transform_from_pts(
+        self,
+        pts: torch.Tensor,
+        dsize: int,
+        scale: float = 1.5,
+        vx_ratio: float = 0.0,
+        vy_ratio: float = -0.1,
+        flag_do_rot: bool = True,
+        use_lip: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Estimate a similarity transform from sparse points.
+
+        Calculate the affine matrix that maps the original image to the cropped
+        image; its inverse maps the cropped image back to the original image.
+
+        pts: landmarks, 101 or 68 points or other points, Nx2
+        scale: the larger the scale factor, the smaller the face ratio
+        vx_ratio: x shift
+        vy_ratio: y shift; the smaller the y shift, the lower the face region
+        flag_do_rot: if True, apply rotation correction
+        """
+        center, size, angle = self.parse_rect_from_landmark(
+            pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio, use_lip=use_lip
+        )
+
+        s = dsize / size[0]
+        tgt_center = torch.tensor([dsize / 2.0, dsize / 2.0], dtype=pts.dtype)
+
+        if flag_do_rot:
+            costheta = torch.cos(angle)
+            sintheta = torch.sin(angle)
+            cx, cy = center[0], center[1]
+            tcx, tcy = tgt_center[0], tgt_center[1]
+
+            M_INV = torch.zeros((2, 3), dtype=pts.dtype, device=pts.device)
+            M_INV[0, 0] = s * costheta
+            M_INV[0, 1] = s * sintheta
+            M_INV[0, 2] = tcx - s * (costheta * cx + sintheta * cy)
+            M_INV[1, 0] = -s * sintheta
+            M_INV[1, 1] = s * costheta
+            M_INV[1, 2] = tcy - s * (-sintheta * cx + costheta * cy)
+        else:
+            M_INV = torch.zeros((2, 3), dtype=pts.dtype, device=pts.device)
+            M_INV[0, 0] = s
+            M_INV[0, 1] = 0.0
+            M_INV[0, 2] = tgt_center[0] - s * center[0]
+            M_INV[1, 0] = 0.0
+            M_INV[1, 1] = s
+            M_INV[1, 2] = tgt_center[1] - s * center[1]
+
+        M_INV_H = torch.cat(
+            [M_INV, torch.tensor([[0, 0, 1]], dtype=M_INV.dtype, device=pts.device)], dim=0
+        )
+        M = torch.inverse(M_INV_H)
+
+        return M_INV, M[:2, :]
+
+    def parse_rect_from_landmark(
+        self,
+        pts: torch.Tensor,
+        scale: float = 1.5,
+        need_square: bool = True,
+        vx_ratio: float = 0.0,
+        vy_ratio: float = 0.0,
+        use_deg_flag: bool = False,
+        use_lip: bool = True,
+    ):
+        """Parse center, size, angle from 101/68/5/x landmarks.
+
+        vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size
+        vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size,
+            which is used to contain more forehead area
+        """
+        pt2 = self.parse_pt2_from_pt68(pts, use_lip=use_lip)
+
+        if not use_lip:
+            v = pt2[1] - pt2[0]
+            new_pt1 = torch.stack([pt2[0, 0] - v[1], pt2[0, 1] + v[0]], dim=0)
+            pt2 = torch.stack([pt2[0], new_pt1], dim=0)
+
+        uy = pt2[1] - pt2[0]
+        uy_norm = torch.norm(uy)
+        if uy_norm.item() <= 1e-3:
+            uy = torch.tensor([0.0, 1.0], dtype=pts.dtype, device=pts.device)
+        else:
+            uy = uy / uy_norm
+        ux = torch.stack((uy[1], -uy[0]))
+
+        angle = torch.acos(ux[0])
+        if ux[1].item() < 0:
+            angle = -angle
+
+        M = torch.stack([ux, uy], dim=0)
+
+        center0 = torch.mean(pts, dim=0)
+        rpts = torch.matmul(pts - center0, M.T)
+        lt_pt = torch.min(rpts, dim=0)[0]
+        rb_pt = torch.max(rpts, dim=0)[0]
+        center1 = (lt_pt + rb_pt) / 2
+
+        size = rb_pt - lt_pt
+        if need_square:
+            m = torch.max(size[0], size[1])
+            size = torch.stack([m, m])
+
+        size = size * scale
+        center = center0 + ux * center1[0] + uy * center1[1]
+        center = center + ux * (vx_ratio * size) + uy * (vy_ratio * size)
+
+        if use_deg_flag:
+            angle = torch.rad2deg(angle)
+
+        return center, size, angle
+
+    def parse_pt2_from_pt68(self, pt68: torch.Tensor, use_lip: bool = True) -> torch.Tensor:
+        if use_lip:
+            left_eye = pt68[42:48].mean(dim=0)
+            right_eye = pt68[36:42].mean(dim=0)
+            mouth_center = (pt68[48] + pt68[54]) / 2.0
+
+            pt68_new = torch.stack([left_eye, right_eye, mouth_center], dim=0)
+            pt2 = torch.stack([(pt68_new[0] + pt68_new[1]) / 2.0, pt68_new[2]], dim=0)
+        else:
+            left_eye = pt68[42:48].mean(dim=0)
+            right_eye = pt68[36:42].mean(dim=0)
+
+            pt2 = torch.stack([left_eye, right_eye], dim=0)
+
+            v = pt2[1] - pt2[0]
+            pt2[1, 0] = pt2[0, 0] - v[1]
+            pt2[1, 1] = pt2[0, 1] + v[0]
+
+        return pt2
+
+    def _transform_img(self, img: torch.Tensor, M: torch.Tensor, dsize: int):
+        """Conduct a similarity or affine transformation of the image;
+        no border operation is performed.
+
+        img: input image, [C, H, W]
+        M: 2x3 or 3x3 affine matrix
+        dsize: target size; an int for a square output, or a (height, width) tuple
+        """
+        if isinstance(dsize, (tuple, list)):
+            out_h, out_w = dsize
+        else:
+            out_h = out_w = dsize
+
+        C, H_in, W_in = img.shape
+
+        M_norm = self._normalize_affine(M, W_in, H_in, out_w, out_h)
+        grid = F.affine_grid(M_norm.unsqueeze(0), [1, C, out_h, out_w], align_corners=False)
+        img = img.unsqueeze(0)
+
+        img_warped = F.grid_sample(
+            img, grid, align_corners=False, mode="bilinear", padding_mode="zeros"
+        )
+        img_warped = img_warped.squeeze(0)
+        img_warped = img_warped.permute(1, 2, 0)  # [H, W, C]
+
+        return img_warped
+
+    def _normalize_affine(self, M: torch.Tensor, W_in: int, H_in: int, W_out: int, H_out: int):
+        device = M.device
+        dtype = M.dtype
+
+        M_h = torch.cat([M, torch.tensor([[0.0, 0.0, 1.0]], dtype=dtype, device=device)], dim=0)
+
+        M_h_inv = torch.inverse(M_h)
+
+        W_in_f = float(W_in)
+        H_in_f = float(H_in)
+        W_out_f = float(W_out)
+        H_out_f = float(H_out)
+
+        S_in = torch.zeros(3, 3, dtype=dtype, device=device)
+        S_in[0, 0] = 2.0 / W_in_f
+        S_in[0, 1] = 0.0
+        S_in[0, 2] = -1.0
+        S_in[1, 0] = 0.0
+        S_in[1, 1] = 2.0 / H_in_f
+        S_in[1, 2] = -1.0
+        S_in[2, 0] = 0.0
+        S_in[2, 1] = 0.0
+        S_in[2, 2] = 1.0
+
+        S_out = torch.zeros(3, 3, dtype=dtype, device=device)
+        S_out[0, 0] = W_out_f / 2.0
+        S_out[0, 1] = 0.0
+        S_out[0, 2] = W_out_f / 2.0
+        S_out[1, 0] = 0.0
+        S_out[1, 1] = H_out_f / 2.0
+        S_out[1, 2] = H_out_f / 2.0
+        S_out[2, 0] = 0.0
+        S_out[2, 1] = 0.0
+        S_out[2, 2] = 1.0
+
+        M_combined = torch.matmul(torch.matmul(S_in, M_h_inv), S_out)
+
+        return M_combined[:2, :]
+
+    def _transform_pts(self, pts: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
+        return torch.matmul(pts, M[:, :2].T) + M[:, 2]
+
+
+class LivePortraitNukeAppearanceFeatureExtractor(nn.Module):
+    """Live Portrait appearance feature extractor model for Nuke.
+
+    Args:
+        model: The appearance feature extractor model.
+    """
+
+    def __init__(self, model) -> None:
+        """Initialize the model."""
+        super().__init__()
+        self.model = model
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        b, c, h, w = x.shape
+
+        out = self.model(x)  # Tensor[1, 32, 16, 64, 64]
+
+        # Pack the [1, 32, 16, 64, 64] volume as 16x16 mosaics of 64x64 tiles
+        # in the red and green channels
+        out_block = out.view([1, 2, 16, 16, 64, 64])
+        out_block = out_block.permute(0, 1, 2, 4, 3, 5).reshape(1, 2, 16 * 64, 64 * 16)
+        return out_block
+
+
+class LivePortraitNukeMotionExtractor(nn.Module):
+    """Live Portrait motion extractor model for Nuke.
+
+    Args:
+        model: The motion extractor model.
+ """ + + def __init__(self, model) -> None: + """Initialize the model.""" + super().__init__() + self.model = model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + + kp_info = self.model(x) + + for k, v in kp_info.items(): + if isinstance(v, torch.Tensor): + kp_info[k] = v.float() + + bs = kp_info["kp"].shape[0] + + pitch = self.headpose_pred_to_degree(kp_info["pitch"]) # Bx1 + yaw = self.headpose_pred_to_degree(kp_info["yaw"]) # Bx1 + roll = self.headpose_pred_to_degree(kp_info["roll"]) # Bx1 + rot_mat = get_rotation_matrix(pitch, yaw, roll) # Bx3x3 + exp = kp_info["exp"] # .reshape(bs, -1, 3) # BxNx3 + kp = kp_info["kp"] # .reshape(bs, -1, 3) # BxNx3 + scale = kp_info["scale"] + t = kp_info["t"] + + if kp.ndim == 2: + num_kp = kp.shape[1] // 3 # Bx(num_kpx3) + else: + num_kp = kp.shape[1] # Bxnum_kpx3 + + kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3) + kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3) + kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty + + out = torch.zeros([b, 1, h, w], device=x.device) + out[0, 0, 0, :3] = torch.cat([pitch, yaw, roll], dim=0) + out[0, 0, 1, :9] = rot_mat.reshape(-1) + out[0, 0, 2, :63] = kp.reshape(-1) + out[0, 0, 3, :63] = kp_transformed.reshape(-1) + out[0, 0, 4, :63] = exp.reshape(-1) + out[0, 0, 5, :1] = scale.reshape(-1) + out[0, 0, 6, :3] = t.reshape(-1) + + return out.contiguous() + + def headpose_pred_to_degree(self, x: torch.Tensor) -> torch.Tensor: + idx_tensor = torch.arange(0, 66, device=x.device, dtype=torch.float32) + pred = F.softmax(x, dim=1) + degree = torch.sum(pred * idx_tensor, dim=1) * 3 - 97.5 + return degree + + +class LivePortraitNukeWarpingModule(nn.Module): + """LivePortraitNukeMotionExtractor model for Nuke. + + Args: + model: The encoder model. + n: Depth Anything window list parameter. + """ + + def __init__(self, model) -> None: + """Initialize the model.""" + super().__init__() + self.model = model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Input tensor [1,2,1040,1024] - 2 channels image, 1040x1024 + # Split the tensor x into feature_3d, kp_source, and kp_driving + kp_source = x[:, 2, 0, :63].reshape(1, 21, 3).contiguous() + kp_driving = x[:, 2, 1, :63].reshape(1, 21, 3).contiguous() + feature_3d = x[:, :2, :, :] # .reshape(1, 32, 16, 64, 64) + feature_3d = ( + feature_3d.view(1, 2, 16, 64, 16, 64) + .permute(0, 1, 2, 4, 3, 5) + .contiguous() + .view(1, 32, 16, 64, 64) + .contiguous() + ) + + out_dct = self.model(feature_3d, kp_driving, kp_source) + out = out_dct["out"] + + assert out is not None + + out = out.view(1, 16, 16, 64, 64) + out = out.permute(0, 1, 3, 2, 4) + out = out.reshape(1, 1, 1024, 1024) + return out.contiguous() + + +class LivePortraitNukeSpadeGenerator(nn.Module): + """LivePortraitNukeMotionExtractor model for Nuke. + + Args: + model: The encoder model. + n: Depth Anything window list parameter. + """ + + def __init__(self, model) -> None: + """Initialize the model.""" + super().__init__() + self.model = model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + + # Input tensor [1, 1, 1040,1024] to [1, 256, 64, 64] + x = x.view(1, 16, 64, 16, 64) + x = x.permute(0, 1, 3, 2, 4).reshape(1, 256, 64, 64) + out = self.model(feature=x) + out = out.contiguous() + return out + + +class LivePortraitNukeStitchingModule(nn.Module): + """LivePortraitNukeMotionExtractor model for Nuke. + + Args: + model: The encoder model. 
+    """
+
+    def __init__(self, model) -> None:
+        """Initialize the model."""
+        super().__init__()
+        self.model = model
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        b, c, h, w = x.shape
+
+        kp_source = x[:, 0, 0, :63].reshape(1, 21, 3).contiguous()  # 1x21x3
+        kp_driving = x[:, 0, 1, :63].reshape(1, 21, 3).contiguous()  # 1x21x3
+        kp_driving_new = kp_driving.clone()
+
+        bs, num_kp = kp_source.shape[:2]
+
+        feat_stitching = self.concat_feat(kp_source, kp_driving)
+        delta = self.model(feat_stitching)  # 1x65
+        delta_exp = delta[..., : 3 * num_kp].reshape(bs, num_kp, 3)  # 1x21x3
+        delta_tx_ty = delta[..., 3 * num_kp : 3 * num_kp + 2].reshape(bs, 1, 2)  # 1x1x2
+
+        kp_driving_new += delta_exp
+        kp_driving_new[..., :2] += delta_tx_ty
+
+        out = torch.zeros([b, 1, h, w], device=x.device)
+        out[0, 0, 0, :63] = kp_driving_new.reshape(-1)
+
+        return out.contiguous()
+
+    def concat_feat(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
+        """
+        kp_source: (bs, k, 3)
+        kp_driving: (bs, k, 3)
+        Return: (bs, 2k*3)
+        """
+        bs_src = kp_source.shape[0]
+        bs_dri = kp_driving.shape[0]
+        assert bs_src == bs_dri, "batch size must be equal"
+
+        feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1)
+        return feat
+
+
+# --- Tracing original models ---
+
+
+def face_detection():
+    """
+    SFDDetector = getattr(__import__('custom_nodes.ComfyUI-LivePortraitKJ.face_alignment.detection.sfd.sfd_detector', fromlist=['']), 'SFDDetector')
+    sfd_detector = SFDDetector("cuda")
+    sf3d = sfd_detector.face_detector
+    sfd_detector_traced = torch.jit.script(sfd_detector.face_detector)
+    sfd_detector_traced.save("sfd_detector_traced.pt")
+    """
+    sfd_detector_traced = torch.jit.load("./pretrained_weights/sfd_detector_traced.pt")
+    return sfd_detector_traced
+
+
 def trace_appearance_feature_extractor():
+    LOGGER.info("--- Tracing appearance_feature_extractor ---")
     appearance_feature_extractor = load_model(
-        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth",
-        model_config=model_config,
+        ckpt_path="./pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth",
+        model_config=MODEL_CONFIG,
         device=0,
         model_type="appearance_feature_extractor",
     )
@@ -73,17 +935,20 @@ def trace_appearance_feature_extractor():
     appearance_feature_extractor = torch.jit.script(appearance_feature_extractor)
     LOGGER.info("Traced appearance_feature_extractor")
 
-    torch.jit.save(appearance_feature_extractor, "build/appearance_feature_extractor.pt")
+    destination = "./build/appearance_feature_extractor.pt"
+    torch.jit.save(appearance_feature_extractor, destination)
+    LOGGER.info("Model saved to: %s", destination)
 
-trace_appearance_feature_extractor()  # done
+    return appearance_feature_extractor
 
 
 def trace_motion_extractor():
-    import src.modules.convnextv2
+    LOGGER.info("--- Tracing motion_extractor ---")
+
     motion_extractor = load_model(
-        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/motion_extractor.pth",
-        model_config=model_config,
+        ckpt_path="./pretrained_weights/liveportrait/base_models/motion_extractor.pth",
+        model_config=MODEL_CONFIG,
         device=0,
         model_type="motion_extractor",
     )
@@ -91,20 +956,22 @@ def trace_motion_extractor():
     with torch.no_grad():
         motion_extractor.eval()
         motion_extractor = torch.jit.script(motion_extractor)
-        # model = src.modules.convnextv2.ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], nuk_kp=21)
-        # model = torch.jit.script(model.downsample_layers)
 
     LOGGER.info("Traced motion_extractor")
-    torch.jit.save(motion_extractor, "build/motion_extractor.pt")
+    destination = "./build/motion_extractor.pt"
+    torch.jit.save(motion_extractor, destination)
+    LOGGER.info("Model saved to: %s", destination)
 
-trace_motion_extractor()
+    return motion_extractor
 
 
 def trace_warping_module():
+    LOGGER.info("--- Tracing warping_module ---")
+
     warping_module = load_model(
-        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/base_models/warping_module.pth",
-        model_config=model_config,
+        ckpt_path="./pretrained_weights/liveportrait/base_models/warping_module.pth",
+        model_config=MODEL_CONFIG,
         device=0,
         model_type="warping_module",
     )
@@ -114,16 +981,20 @@ def trace_warping_module():
     warping_module = torch.jit.script(warping_module)
     LOGGER.info("Traced warping_module")
 
-    torch.jit.save(warping_module, "build/warping_module.pt")
+    destination = "./build/warping_module.pt"
+    torch.jit.save(warping_module, destination)
+    LOGGER.info("Model saved to: %s", destination)
 
-trace_warping_module()  # done
+    return warping_module
 
 
 def trace_spade_generator():
+    LOGGER.info("--- Tracing spade_generator ---")
+
     spade_generator = load_model(
         ckpt_path="./pretrained_weights/liveportrait/base_models/spade_generator.pth",
-        model_config=model_config,
+        model_config=MODEL_CONFIG,
         device=0,
         model_type="spade_generator",
     )
@@ -133,223 +1004,188 @@ def trace_spade_generator():
     spade_generator = torch.jit.script(spade_generator)
     LOGGER.info("Traced spade_generator")
 
-    torch.jit.save(spade_generator, "build/spade_generator.pt")
+    destination = "./build/spade_generator.pt"
+    torch.jit.save(spade_generator, destination)
+    LOGGER.info("Model saved to: %s", destination)
 
-trace_spade_generator()  # done
+    return spade_generator
 
 
 def trace_stitching_retargeting_module():
+    LOGGER.info("--- Tracing stitching_retargeting_module ---")
+
     stitching_retargeting_module = load_model(
-        ckpt_path="/mnt/x/1_projects/relight/LivePortrait/src/config/../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth",
-        model_config=model_config,
+        ckpt_path="./pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth",
+        model_config=MODEL_CONFIG,
        device=0,
         model_type="stitching_retargeting_module",
     )
 
     with torch.no_grad():
-        stitching = stitching_retargeting_module['stitching'].eval()
-        lip = stitching_retargeting_module['lip'].eval()
-        eye = stitching_retargeting_module['eye'].eval()
+        stitching = stitching_retargeting_module["stitching"].eval()
+        lip = stitching_retargeting_module["lip"].eval()
+        eye = stitching_retargeting_module["eye"].eval()
 
         stitching_trace = torch.jit.script(stitching)
         lip_trace = torch.jit.script(lip)
         eye_trace = torch.jit.script(eye)
 
     LOGGER.info("Traced stitching_retargeting_module")
-    torch.jit.save(stitching_trace, "build/stitching_retargeting_module_stitching.pt")
-    torch.jit.save(lip_trace, "build/stitching_retargeting_module_lip.pt")
-    torch.jit.save(eye_trace, "build/stitching_retargeting_module_eye.pt")
+
+    destination = "./build/stitching_retargeting_module_stitching.pt"
+    torch.jit.save(stitching_trace, destination)
+
+    destination = "./build/stitching_retargeting_module_eye.pt"
+    torch.jit.save(eye_trace, destination)
+
+    destination = "./build/stitching_retargeting_module_lip.pt"
+    torch.jit.save(lip_trace, destination)
+
+    LOGGER.info("Model saved to: %s", destination)
+
+    return stitching_trace, lip_trace, eye_trace
 
 
-trace_stitching_retargeting_module()  # done
-
-"""
-
-from src.modules.util import SPADEResnetBlock
-from torch.nn.utils.spectral_norm import spectral_norm, SpectralNorm
-from torch.nn.utils.parametrizations import spectral_norm
-
-fin=512; fout=512; norm_G='spadespectralinstance'; label_nc=256; use_se=False; dilation=1
-fmiddle = min(fin, fout)
-c = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
-s = spectral_norm(c)
-j = spectral_norm(c).eval()
-torch.jit.script(s)
-torch.jit.trace(s, torch.randn(1, 512, 64, 64))
-m = torch.jit.trace(j, torch.randn(1, 512, 64, 64))
-
-with torch.no_grad():
-    torch.jit.script(s)
+# --- Tracing Nuke models ---
-
-with torch.no_grad():
-    torch.jit.trace(s, torch.randn(1, 512, 64, 64))
+
+
+def trace_face_detection_nuke(run_test=False):
+    sf3d_face_detection = torch.jit.load("./pretrained_weights/sfd_detector_traced.pt").cuda()
+    face_alignment = torch.jit.load(
+        "./pretrained_weights/from_kj/2DFAN4-cd938726ad.zip"
+    ).cuda()
+
+    model = LivePortraitNukeFaceDetection(
+        face_detection=sf3d_face_detection, face_alignment=face_alignment
+    )
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn(1, 3, 720, 1280).cuda()
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
+
+    if run_test:
+        test_forward()
+
+    model_traced = torch.jit.script(model)
+    destination = "./build/face_detection_nuke.pt"
+    torch.jit.save(model_traced, destination)
+
+    LOGGER.info("Model saved to: %s", destination)
+
+
+def trace_appearance_feature_extractor_nuke(run_test=False):
+    appearance_feature_extractor = trace_appearance_feature_extractor()
+    model = LivePortraitNukeAppearanceFeatureExtractor(model=appearance_feature_extractor)
-
-sp = SPADEResnetBlock(fin=512, fout=512, norm_G="spadespectralinstance", label_nc=256, use_se=False, dilation=1)
-sp.eval()
-sp_traced = torch.jit.trace(sp, (torch.randn(1, 512, 64, 64), torch.randn(1, 256, 64, 64)))
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn(1, 3, 256, 256).cuda()
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
+
+    if run_test:
+        test_forward()
+
+    model_traced = torch.jit.script(model)
+
+    destination = "./build/appearance_feature_extractor_nuke.pt"
+    torch.jit.save(model_traced, destination)
+    LOGGER.info("Model saved to: %s", destination)
-
-import torch
-from torch import nn
+
+
+def trace_motion_extractor_nuke(run_test=False):
+    motion_extractor = trace_motion_extractor()
+    model = LivePortraitNukeMotionExtractor(model=motion_extractor)
+    model_traced = torch.jit.script(model)
-
-with torch.no_grad():
-    c1 = torch.nn.utils.spectral_norm(torch.nn.Conv2d(1,2,3))
-    torch.jit.script(c1)
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn(1, 3, 256, 256).cuda()
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
+
+    if run_test:
+        test_forward()
+
+    destination = "./build/motion_extractor_nuke.pt"
+    torch.jit.save(model_traced, destination)
+    LOGGER.info("Model saved to: %s", destination)
+    return model_traced
-
-c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
-torch.jit.script(c2)
+
+
+def trace_warping_module_nuke(run_test=False):
+    warping_module = trace_warping_module()
+    model = LivePortraitNukeWarpingModule(model=warping_module)
-
-with torch.no_grad():
-    c2 = torch.nn.utils.parametrizations.spectral_norm(torch.nn.Conv2d(1,2,3))
-    torch.jit.script(c2)
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn([1, 3, 1024, 1024], dtype=torch.float32, device="cuda")
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
-
-l = nn.Linear(20, 40)
-l.weight.size()
+
+    if run_test:
+        test_forward()
-
-m = torch.nn.utils.spectral_norm(nn.Linear(20, 40))
-m
-m.weight_u.size()
-
----
-
-def modify_state_dict_inplace(model):
-    state_dict = model.state_dict()
-    keys_to_delete = []
-
-    for key in list(state_dict.keys()):
-        value = state_dict[key]
-        k = key.split(".")
-
-        if len(k) == 3 and k[-1] == "weight" and k[-2] in ["conv_0", "conv_1"]:
-            # Register new parameters
-            model.register_parameter(f"{k[0]}.{k[-2]}.weight_orig", torch.nn.Parameter(value))
-            model.register_parameter(f"{k[0]}.{k[-2]}.weight_u", torch.nn.Parameter(torch.zeros_like(value)))
-            model.register_parameter(f"{k[0]}.{k[-2]}.weight_v", torch.nn.Parameter(torch.zeros_like(value)))
-
-            state_dict[f"{k[0]}.{k[-2]}.weight_orig"] = value
-            keys_to_delete.append(key)
-            state_dict[f"{k[0]}.{k[-2]}.weight_u"] = torch.zeros_like(value)
-            state_dict[f"{k[0]}.{k[-2]}.weight_v"] = torch.zeros_like(value)
-
-    # for key in keys_to_delete:
-    #     delattr(module, key)
-
-    # Load the modified state_dict back into the model
-    model.load_state_dict(state_dict)
-    return model
-
-model = SPADEDecoder(**model_params).cuda(device)
+
+    model_traced = torch.jit.script(model)
+    destination = "./build/warping_module_nuke.pt"
+    torch.jit.save(model_traced, destination)
+    LOGGER.info("Model saved to: %s", destination)
-
-modify_state_dict_inplace(model)
-model.state_dict().keys()
+
+
+def trace_spade_generator_nuke(run_test=False):
+    spade_generator = trace_spade_generator()
+    model = LivePortraitNukeSpadeGenerator(model=spade_generator)
-
-model._parameters[name]
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn([1, 1, 1024, 1024], dtype=torch.float32, device="cuda")
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
-
-model = modify_state_dict_inplace(model)
-model.state_dict().keys()
+
+    if run_test:
+        test_forward()
-
-model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
+
+    model_traced = torch.jit.script(model)
+    destination = "./build/spade_generator_nuke.pt"
+    torch.jit.save(model_traced, destination)
+    LOGGER.info("Model saved to: %s", destination)
+
+
+def trace_stitching_retargeting_module_nuke(run_test=False):
+    stitching_model, lip_model, eye_model = trace_stitching_retargeting_module()
-
-modify_state_dict_inplace(model.state_dict())
+
+    model = LivePortraitNukeStitchingModule(model=stitching_model)
-
-modified_state_dict.keys()
-model.register_parameter("aaalero", torch.nn.Parameter(torch.randn(1, 2, 3)))
-model.state_dict = modified_state_dict
-model = SPADEDecoder(**model_params).cuda(device)
-model.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage), strict=True)
+
+    def test_forward():
+        with torch.no_grad():
+            m = torch.randn([1, 1, 64, 64], dtype=torch.float32, device="cuda")
+            model.eval()
+            out = model(m)
+            LOGGER.info(out.shape)
-
-type(state_dict.keys())
-type(modified_state_dict)
+
+    if run_test:
+        test_forward()
-
-missing_keys = ["G_middle_0.conv_0.weight_orig", "G_middle_0.conv_0.weight_u", "G_middle_0.conv_0.weight_v", "G_middle_0.conv_1.weight_orig", "G_middle_0.conv_1.weight_u", "G_middle_0.conv_1.weight_v", "G_middle_1.conv_0.weight_orig", "G_middle_1.conv_0.weight_u", "G_middle_1.conv_0.weight_v", "G_middle_1.conv_1.weight_orig", "G_middle_1.conv_1.weight_u", "G_middle_1.conv_1.weight_v", "G_middle_2.conv_0.weight_orig", "G_middle_2.conv_0.weight_u", "G_middle_2.conv_0.weight_v", "G_middle_2.conv_1.weight_orig", "G_middle_2.conv_1.weight_u", "G_middle_2.conv_1.weight_v", "G_middle_3.conv_0.weight_orig", "G_middle_3.conv_0.weight_u", "G_middle_3.conv_0.weight_v", "G_middle_3.conv_1.weight_orig", "G_middle_3.conv_1.weight_u", "G_middle_3.conv_1.weight_v", "G_middle_4.conv_0.weight_orig", "G_middle_4.conv_0.weight_u", "G_middle_4.conv_0.weight_v", "G_middle_4.conv_1.weight_orig", "G_middle_4.conv_1.weight_u", "G_middle_4.conv_1.weight_v", "G_middle_5.conv_0.weight_orig", "G_middle_5.conv_0.weight_u", "G_middle_5.conv_0.weight_v", "G_middle_5.conv_1.weight_orig", "G_middle_5.conv_1.weight_u", "G_middle_5.conv_1.weight_v", "up_0.conv_0.weight_orig", "up_0.conv_0.weight_u", "up_0.conv_0.weight_v", "up_0.conv_1.weight_orig", "up_0.conv_1.weight_u", "up_0.conv_1.weight_v", "up_0.conv_s.weight_orig", "up_0.conv_s.weight_u", "up_0.conv_s.weight_v", "up_1.conv_0.weight_orig", "up_1.conv_0.weight_u", "up_1.conv_0.weight_v", "up_1.conv_1.weight_orig", "up_1.conv_1.weight_u", "up_1.conv_1.weight_v", "up_1.conv_s.weight_orig", "up_1.conv_s.weight_u", "up_1.conv_s.weight_v"]
"G_middle_2.conv_1.weight_v", "G_middle_3.conv_0.weight_orig", "G_middle_3.conv_0.weight_u", "G_middle_3.conv_0.weight_v", "G_middle_3.conv_1.weight_orig", "G_middle_3.conv_1.weight_u", "G_middle_3.conv_1.weight_v", "G_middle_4.conv_0.weight_orig", "G_middle_4.conv_0.weight_u", "G_middle_4.conv_0.weight_v", "G_middle_4.conv_1.weight_orig", "G_middle_4.conv_1.weight_u", "G_middle_4.conv_1.weight_v", "G_middle_5.conv_0.weight_orig", "G_middle_5.conv_0.weight_u", "G_middle_5.conv_0.weight_v", "G_middle_5.conv_1.weight_orig", "G_middle_5.conv_1.weight_u", "G_middle_5.conv_1.weight_v", "up_0.conv_0.weight_orig", "up_0.conv_0.weight_u", "up_0.conv_0.weight_v", "up_0.conv_1.weight_orig", "up_0.conv_1.weight_u", "up_0.conv_1.weight_v", "up_0.conv_s.weight_orig", "up_0.conv_s.weight_u", "up_0.conv_s.weight_v", "up_1.conv_0.weight_orig", "up_1.conv_0.weight_u", "up_1.conv_0.weight_v", "up_1.conv_1.weight_orig", "up_1.conv_1.weight_u", "up_1.conv_1.weight_v", "up_1.conv_s.weight_orig", "up_1.conv_s.weight_u", "up_1.conv_s.weight_v"] - -missing_keys in list(modified_state_dict.keys()) + model_traced = torch.jit.script(model) + destination = "./build/stitching_retargeting_module_stitching_nuke.pt" + torch.jit.save(model_traced, destination) + LOGGER.info("Model saved to: %s", destination) -def apply(module: Module, name: str): - weight = module._parameters[name] +if __name__ == "__main__": - delattr(module, fn.name) - module.register_parameter(fn.name + "_orig", weight) - # We still need to assign weight back as fn.name because all sorts of - # things may assume that it exists, e.g., when initializing weights. - # However, we can't directly assign as it could be an nn.Parameter and - # gets added as a parameter. Instead, we register weight.data as a plain - # attribute. 
-    setattr(module, fn.name, weight.data)
-    module.register_buffer(fn.name + "_u", u)
-    module.register_buffer(fn.name + "_v", v)
-
-    module.register_forward_pre_hook(fn)
-    module._register_state_dict_hook(SpectralNormStateDictHook(fn))
-    module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
-    return fn
-
-
-dims = [96, 192, 384, 768]
-from typing import List
-
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class LayerNorm(nn.Module):
-    def __init__(self, normalized_shape: int, eps: float = 1e-6, data_format: str = "channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.data_format = data_format
-        self.normalized_shape = (normalized_shape,)
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.data_format == "channels_last":
-            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None, None] * x + self.bias[:, None, None]
-            return x
-        else:
-            raise ValueError(f"Unsupported data_format: {self.data_format}")
-    def extra_repr(self) -> str:
-        return f'normalized_shape={self.normalized_shape}, ' \
-            f'eps={self.eps}, data_format={self.data_format}'
-
-
-class YourModule(nn.Module):
-    def __init__(self, in_chans: int, dims: List[int]):
-        super().__init__()
-        self.downsample_layers = nn.ModuleList()
-        stem = nn.Sequential(
-            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
-            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
-        )
-        self.downsample_layers.append(stem)
-        for i in range(3):
-            downsample_layer = nn.Sequential(
-                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
-                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
-            )
-            self.downsample_layers.append(downsample_layer)
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for layer in self.downsample_layers:
-            x = layer(x)
-        return x
-
-
-
-# Now try to script the entire module
-model = YourModule(3, dims)
-torch.jit.script(model)
-
-"""
+    run_test = True
+    trace_face_detection_nuke(run_test)
+    trace_appearance_feature_extractor_nuke(run_test)
+    trace_motion_extractor_nuke(run_test)
+    trace_warping_module_nuke(run_test)
+    trace_spade_generator_nuke(run_test)
+    trace_stitching_retargeting_module_nuke(run_test)
diff --git a/src/utils/camera.py b/src/utils/camera.py
index a3dd942..734c751 100644
--- a/src/utils/camera.py
+++ b/src/utils/camera.py
@@ -32,6 +32,8 @@ def get_rotation_matrix(pitch_, yaw_, roll_):
     """ the input is in degree """
     # transform to radian
+    PI = np.pi
+
     pitch = pitch_ / 180 * PI
     yaw = yaw_ / 180 * PI
     roll = roll_ / 180 * PI
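
Note on wiring: below is a minimal sketch of how the traced Nuke models are expected to chain together at inference time. It is not part of the patch; it assumes the artifacts written to ./build by the trace_*_nuke() functions above and the packed channel layouts described in the forward() methods. In particular, feeding motion-extractor row 3 (kp_transformed) into both kp_source and kp_driving is an assumption, not something this patch confirms.

import torch

device = torch.device("cuda")
appearance = torch.jit.load("./build/appearance_feature_extractor_nuke.pt", map_location=device)
motion = torch.jit.load("./build/motion_extractor_nuke.pt", map_location=device)
warping = torch.jit.load("./build/warping_module_nuke.pt", map_location=device)
spade = torch.jit.load("./build/spade_generator_nuke.pt", map_location=device)

source = torch.rand(1, 3, 256, 256, device=device)   # cropped source face
driving = torch.rand(1, 3, 256, 256, device=device)  # cropped driving frame

with torch.no_grad():
    feat = appearance(source)  # [1, 2, 1024, 1024]: the [1, 32, 16, 64, 64] volume packed in R/G
    kp_src = motion(source)    # [1, 1, 256, 256]: rows 0-6 hold pose, rot_mat, kp, kp_transformed, exp, scale, t
    kp_dri = motion(driving)

    # Assemble the 3-channel warping input: features in red/green, keypoints in blue
    x = torch.zeros(1, 3, 1024, 1024, device=device)
    x[:, :2] = feat
    x[0, 2, 0, :63] = kp_src[0, 0, 3, :63]  # row 3: kp_transformed, 21 x 3 values (assumed)
    x[0, 2, 1, :63] = kp_dri[0, 0, 3, :63]

    warped = warping(x)  # [1, 1, 1024, 1024]: the [1, 256, 64, 64] feature map packed in one channel
    rgb = spade(warped)  # final render from the SPADE decoder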