chore: format code

This commit is contained in:
guojianzhu 2024-07-12 14:34:28 +08:00
parent 89676189b9
commit c3f01d3f3b
3 changed files with 22 additions and 42 deletions

View File

@ -2,11 +2,11 @@
import os.path as osp
import tyro
import subprocess
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig
from src.config.crop_config import CropConfig
from src.live_portrait_pipeline import LivePortraitPipeline
import subprocess
def partial_fields(target_class, kwargs):
@ -37,19 +37,17 @@ def main():
raise ImportError(
"FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
)
# fast check the args
fast_check_args(args)
# specify configs for inference
inference_cfg = partial_fields(
InferenceConfig, args.__dict__
) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(
CropConfig, args.__dict__
) # use attribute of args to initial CropConfig
inference_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
live_portrait_pipeline = LivePortraitPipeline(
inference_cfg=inference_cfg, crop_cfg=crop_cfg
inference_cfg=inference_cfg,
crop_cfg=crop_cfg
)
# run

View File

@ -59,17 +59,19 @@ conda activate LivePortrait
pip install -r requirements.txt
```
Make sure your system has [FFmpeg](https://ffmpeg.org/)
**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!
### 2. Download pretrained weights
Download the pretrained weights from HuggingFace:
The easiest way to download the pretrained weights is from HuggingFace:
```bash
# you may need to run `git lfs install` first
git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
```
Or, download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
Ensuring the directory structure is as follows, or contains:
```text
pretrained_weights
├── insightface

View File

@ -4,12 +4,9 @@ import os.path as osp
from dataclasses import dataclass, field
from typing import List, Tuple, Union
import cv2
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from ..config.crop_config import CropConfig
from .crop import (
average_bbox_lst,
@ -17,10 +14,10 @@ from .crop import (
crop_image_by_bbox,
parse_bbox_from_landmark,
)
from .face_analysis_diy import FaceAnalysisDIY
from .io import contiguous
from .landmark_runner import LandmarkRunner
from .rprint import rlog as log
from .face_analysis_diy import FaceAnalysisDIY
from .landmark_runner import LandmarkRunner
def make_abs_path(fn):
@ -34,16 +31,9 @@ class Trajectory:
lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # bbox list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(
default_factory=list
) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(
default_factory=list
) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(
default_factory=list
) # frame crop list
frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame list
lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # lmk list
frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list) # frame crop list
class Cropper(object):
@ -93,9 +83,7 @@ class Cropper(object):
log("No face detected in the source image.")
return None
elif len(src_face) > 1:
log(
f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}."
)
log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
# NOTE: temporarily only pick the first face, to support multiple face in the future
src_face = src_face[0]
@ -115,9 +103,7 @@ class Cropper(object):
ret_dct["lmk_crop"] = lmk
# update a 256x256 version for network input
ret_dct["img_crop_256x256"] = cv2.resize(
ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
)
ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
return ret_dct
@ -137,9 +123,7 @@ class Cropper(object):
log(f"No face detected in the frame #{idx}")
continue
elif len(src_face) > 1:
log(
f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
)
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb, lmk)
@ -166,9 +150,7 @@ class Cropper(object):
global_bbox = average_bbox_lst(trajectory.bbox_lst)
for idx, (frame_rgb, lmk) in enumerate(
zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)
):
for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
ret_dct = crop_image_by_bbox(
frame_rgb,
global_bbox,
@ -201,9 +183,7 @@ class Cropper(object):
log(f"No face detected in the frame #{idx}")
raise Exception(f"No face detected in the frame #{idx}")
elif len(src_face) > 1:
log(
f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
)
log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
src_face = src_face[0]
lmk = src_face.landmark_2d_106
lmk = self.landmark_runner.run(frame_rgb_crop, lmk)