diff --git a/cog.yaml b/cog.yaml index d4b1520..5f49853 100644 --- a/cog.yaml +++ b/cog.yaml @@ -29,4 +29,5 @@ build: - "matplotlib==3.9.0" - "imageio-ffmpeg==0.5.1" - "tyro==0.8.5" + - "gradio==3.48.0" predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index f0957b2..8aca41e 100644 --- a/predict.py +++ b/predict.py @@ -1,9 +1,10 @@ -from cog import BasePredictor, Input, Path +from cog import BasePredictor, Input, Path, File from src.config.argument_config import ArgumentConfig from src.config.inference_config import InferenceConfig from src.config.crop_config import CropConfig from src.live_portrait_pipeline import LivePortraitPipeline +import requests class Predictor(BasePredictor): @@ -16,14 +17,18 @@ class Predictor(BasePredictor): def predict( self, - image: Path = Input(description="Portrait image"), - driving_info: Path = Input( - description="driving video or template (.pkl format)" - ), + image: Path = Input(description="Portrait image") ) -> Path: """Run a single prediction on the model""" + video_path, _ = self.live_portrait_pipeline.execute( - ArgumentConfig(source_image=image, driving_info=driving_info, output_dir="/tmp/") + ArgumentConfig( + source_image=image, + driving_info="assets/examples/driving/d0.mp4", + output_dir="/tmp/", + flag_pasteback=False, + flag_do_crop=False, + ) ) return Path(video_path)