Mirror of https://github.com/KwaiVGI/LivePortrait.git, synced 2024-12-22 20:42:38 +00:00

Commit c3f01d3f3b (parent 89676189b9): chore: format code
inference.py

```diff
@@ -2,11 +2,11 @@
 
 import os.path as osp
 import tyro
+import subprocess
 from src.config.argument_config import ArgumentConfig
 from src.config.inference_config import InferenceConfig
 from src.config.crop_config import CropConfig
 from src.live_portrait_pipeline import LivePortraitPipeline
-import subprocess
 
 
 def partial_fields(target_class, kwargs):
@@ -37,19 +37,17 @@ def main():
         raise ImportError(
             "FFmpeg is not installed. Please install FFmpeg before running this script. https://ffmpeg.org/download.html"
         )
 
     # fast check the args
     fast_check_args(args)
 
     # specify configs for inference
-    inference_cfg = partial_fields(
-        InferenceConfig, args.__dict__
-    )  # use attribute of args to initial InferenceConfig
-    crop_cfg = partial_fields(
-        CropConfig, args.__dict__
-    )  # use attribute of args to initial CropConfig
+    inference_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attribute of args to initial InferenceConfig
+    crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attribute of args to initial CropConfig
 
     live_portrait_pipeline = LivePortraitPipeline(
-        inference_cfg=inference_cfg, crop_cfg=crop_cfg
+        inference_cfg=inference_cfg,
+        crop_cfg=crop_cfg
     )
 
     # run
```
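The reformatted calls above rely on two pieces whose bodies the hunks never show: the `partial_fields` helper that builds each config from the parsed CLI namespace, and whatever check raises the `ImportError` when FFmpeg is missing. Below is a minimal sketch of plausible implementations; both bodies are assumptions for illustration, not the diffed code.

```python
import subprocess
from dataclasses import fields


def partial_fields(target_class, kwargs):
    # Keep only the kwargs that name a declared field of the target
    # dataclass; extra CLI attributes are silently dropped.
    names = {f.name for f in fields(target_class)}
    return target_class(**{k: v for k, v in kwargs.items() if k in names})


def ffmpeg_is_installed() -> bool:
    # Probe the ffmpeg binary; a missing executable raises FileNotFoundError.
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
```

Under that reading, `partial_fields(InferenceConfig, args.__dict__)` simply projects the parsed `tyro` arguments onto the fields that `InferenceConfig` declares.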
readme.md

````diff
@@ -59,17 +59,19 @@ conda activate LivePortrait
 pip install -r requirements.txt
 ```
 
-Make sure your system has [FFmpeg](https://ffmpeg.org/)
+**Note:** make sure your system has [FFmpeg](https://ffmpeg.org/) installed!
 
 ### 2. Download pretrained weights
 
-Download the pretrained weights from HuggingFace:
+The easiest way to download the pretrained weights is from HuggingFace:
 ```bash
 # you may need to run `git lfs install` first
 git clone https://huggingface.co/KwaiVGI/liveportrait pretrained_weights
 ```
 
-Or, download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). We have packed all weights in one directory 😊. Unzip and place them in `./pretrained_weights` ensuring the directory structure is as follows:
+Alternatively, you can download all pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1UtKgzKjFAOmZkhNK-OYT0caJ_w2XAnib) or [Baidu Yun](https://pan.baidu.com/s/1MGctWmNla_vZxDbEp2Dtzw?pwd=z5cn). Unzip and place them in `./pretrained_weights`.
+
+Ensuring the directory structure is as follows, or contains:
 ```text
 pretrained_weights
 ├── insightface
````
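For anyone without `git lfs`, the same HuggingFace repo can usually be fetched programmatically. A hedged alternative using `huggingface_hub` (the repo id comes from the clone URL above; the package choice is an assumption, not something the readme prescribes):

```python
from huggingface_hub import snapshot_download

# Mirror of the `git clone` above: download KwaiVGI/liveportrait
# into ./pretrained_weights without needing git-lfs.
snapshot_download(repo_id="KwaiVGI/liveportrait", local_dir="pretrained_weights")
```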
src/utils/cropper.py

```diff
@@ -4,12 +4,9 @@ import os.path as osp
 from dataclasses import dataclass, field
 from typing import List, Tuple, Union
 
-import cv2
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
 import numpy as np
 
-cv2.setNumThreads(0)
-cv2.ocl.setUseOpenCL(False)
-
 from ..config.crop_config import CropConfig
 from .crop import (
     average_bbox_lst,
@@ -17,10 +14,10 @@ from .crop import (
     crop_image_by_bbox,
     parse_bbox_from_landmark,
 )
-from .face_analysis_diy import FaceAnalysisDIY
 from .io import contiguous
-from .landmark_runner import LandmarkRunner
 from .rprint import rlog as log
+from .face_analysis_diy import FaceAnalysisDIY
+from .landmark_runner import LandmarkRunner
 
 
 def make_abs_path(fn):
```
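Two details in this hunk are easy to miss. Chaining `cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)` onto the import keeps OpenCV from spinning up its own thread pool or OpenCL context, which matters when the surrounding pipeline manages parallelism itself. And `make_abs_path`, visible only as trailing context, conventionally resolves a path relative to the module file; a sketch under that assumption:

```python
import os.path as osp

def make_abs_path(fn):
    # Resolve `fn` against this source file's directory, so bundled
    # assets are found no matter the current working directory.
    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
```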
```diff
@@ -34,16 +31,9 @@ class Trajectory:
     lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # lmk list
     bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # bbox list
 
-    frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(
-        default_factory=list
-    )  # frame list
-
-    lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(
-        default_factory=list
-    )  # lmk list
-    frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(
-        default_factory=list
-    )  # frame crop list
+    frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame list
+    lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # lmk list
+    frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame crop list
 
 
 class Cropper(object):
```
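The `field(default_factory=list)` pattern that this hunk compacts onto single lines is not cosmetic: dataclasses reject mutable defaults, and the factory gives each instance its own lists. A self-contained illustration, independent of the diffed class:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Track:
    # A bare `items: List[int] = []` would raise ValueError at class
    # creation; default_factory builds a fresh list per instance.
    items: List[int] = field(default_factory=list)

a, b = Track(), Track()
a.items.append(1)
assert b.items == []  # instances do not share state
```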
```diff
@@ -93,9 +83,7 @@ class Cropper(object):
             log("No face detected in the source image.")
             return None
         elif len(src_face) > 1:
-            log(
-                f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}."
-            )
+            log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")
 
         # NOTE: temporarily only pick the first face, to support multiple face in the future
         src_face = src_face[0]
@@ -115,9 +103,7 @@ class Cropper(object):
         ret_dct["lmk_crop"] = lmk
 
         # update a 256x256 version for network input
-        ret_dct["img_crop_256x256"] = cv2.resize(
-            ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
-        )
+        ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
         ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize
 
         return ret_dct
@@ -137,9 +123,7 @@ class Cropper(object):
                 log(f"No face detected in the frame #{idx}")
                 continue
             elif len(src_face) > 1:
-                log(
-                    f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
-                )
+                log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
             src_face = src_face[0]
             lmk = src_face.landmark_2d_106
             lmk = self.landmark_runner.run(frame_rgb, lmk)
```
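A point worth spelling out from the second hunk above: the crop is resized to the 256x256 network input, and the landmarks are rescaled by the same factor, `256 / crop_cfg.dsize`, so they stay aligned with the resized image. A small worked example (the 512 crop size is an assumed value for illustration):

```python
import numpy as np

dsize = 512  # assumed crop resolution, for illustration only
lmk_crop = np.array([[256.0, 128.0]])  # one (x, y) landmark in crop space

# Apply the same scale factor that cv2.resize applies to the pixels.
lmk_crop_256x256 = lmk_crop * 256 / dsize
print(lmk_crop_256x256)  # [[128.  64.]]
```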
```diff
@@ -166,9 +150,7 @@ class Cropper(object):
 
         global_bbox = average_bbox_lst(trajectory.bbox_lst)
 
-        for idx, (frame_rgb, lmk) in enumerate(
-            zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)
-        ):
+        for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
             ret_dct = crop_image_by_bbox(
                 frame_rgb,
                 global_bbox,
@@ -201,9 +183,7 @@ class Cropper(object):
                 log(f"No face detected in the frame #{idx}")
                 raise Exception(f"No face detected in the frame #{idx}")
             elif len(src_face) > 1:
-                log(
-                    f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}."
-                )
+                log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
             src_face = src_face[0]
             lmk = src_face.landmark_2d_106
             lmk = self.landmark_runner.run(frame_rgb_crop, lmk)
```
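`average_bbox_lst`, which produces `global_bbox` in the first hunk, is imported from `.crop` and not shown in this diff; its name and use suggest it averages the per-frame boxes into one fixed crop window so the crop does not jitter across the clip. A hypothetical sketch under that assumption:

```python
import numpy as np

def average_bbox_lst(bbox_lst):
    # Element-wise mean of [x1, y1, x2, y2] boxes across all frames,
    # yielding a single stable crop window for the whole video.
    if len(bbox_lst) == 0:
        return None
    return np.mean(np.asarray(bbox_lst, dtype=np.float32), axis=0).tolist()
```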