From c2c099f5418c06f74f229a2adf5ca8679639a359 Mon Sep 17 00:00:00 2001
From: nkfyz
Date: Fri, 8 Nov 2024 14:05:48 +0800
Subject: [PATCH] add server service

---
 server/Dockerfile       |  41 +++++++++
 server/client.py        |  12 +++
 server/consumer.py      | 191 ++++++++++++++++++++++++++++++++++++++++
 server/proxy_server.py  | 125 ++++++++++++++++++++++++++
 server/requirements.txt |   5 +
 server/server.py        | 176 +++++++++++++++++++++++++++++++++++++
 6 files changed, 550 insertions(+)
 create mode 100644 server/Dockerfile
 create mode 100644 server/client.py
 create mode 100644 server/consumer.py
 create mode 100644 server/proxy_server.py
 create mode 100644 server/requirements.txt
 create mode 100644 server/server.py

diff --git a/server/Dockerfile b/server/Dockerfile
new file mode 100644
index 0000000..e0ac237
--- /dev/null
+++ b/server/Dockerfile
@@ -0,0 +1,41 @@
+ARG CUDA_VERSION=12.4.1
+# **NOTE**: For 4090, change to 12.1.0
+# For H100, change to 12.4.1
+
+FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
+
+LABEL authors="luchentech"
+
+# ====== Install base dependencies ===================
+ARG PYTHON_VERSION=3.10
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update -y && \
+    apt-get install -y software-properties-common && \
+    add-apt-repository ppa:deadsnakes/ppa && \
+    apt-get update -y && \
+    apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
+    update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
+
+# install libgl1 for import error described in https://stackoverflow.com/questions/55313610/
+# Alternative solution: pip install opencv-python-headless
+RUN apt-get update -y && \
+    apt-get install -y python3-pip git curl ffmpeg && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN python3 --version && python --version
+
+# Install pip s.t. it will be compatible with our PYTHON_VERSION
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
+
+WORKDIR /workspace
+
+RUN python3 -m pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
+
+RUN git clone https://github.com/KwaiVGI/LivePortrait.git && \
+    cd LivePortrait && \
+    python3 -m pip install -r requirements.txt
+
+RUN python3 -m pip install --no-cache-dir onnxruntime-gpu==1.17.1 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple
\ No newline at end of file
diff --git a/server/client.py b/server/client.py
new file mode 100644
index 0000000..1bf9615
--- /dev/null
+++ b/server/client.py
@@ -0,0 +1,12 @@
+mapping = {
+    "src_key": "../assets/examples/source/s5.jpg",
+    "driving_key": "../assets/examples/driving/d13.mp4",
+    }
+
+import requests, json
+
+base_url = "http://localhost:8005"
+headers = {"accept": "application/json", "Content-Type": "application/json"}
+
+response = requests.post(f"{base_url}/submit", headers=headers, json=mapping)
+print(response)
diff --git a/server/consumer.py b/server/consumer.py
new file mode 100644
index 0000000..a8cdd2d
--- /dev/null
+++ b/server/consumer.py
@@ -0,0 +1,191 @@
+import argparse
+import logging
+import os
+import sys
+import time
+
+import redis
+import tyro
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, parent_dir)
+
+from src.config.argument_config import ArgumentConfig
+from src.config.inference_config import InferenceConfig
+from src.config.crop_config import CropConfig
+from src.live_portrait_pipeline import LivePortraitPipeline
+
+OUTPUT_LOCAL_PATH = os.getenv("OUTPUT_LOCAL_PATH", "./tmp/outputs/")
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+TASK_PREFIX = "lp-task-"  # Prefix for task keys, used for retrieving data/status
+TASK_TAG = "lp"
+STREAM_NAME_PREFIX = os.getenv("REDIS_STREAM_PREFIX", "task_stream") + "_"
+CONSUMER_GROUP_PREFIX = os.getenv("REDIS_GROUP_PREFIX", "task_group") + "_"
+
+LOCK_KEY_PREFIX = os.getenv("LOCK_KEY_PREFIX", "lock_key") + "_" + TASK_TAG
+CONSUMER_NAME = os.getenv("HOSTNAME", "test_consumer_lp")
+
+# Both are populated by run_worker(): the shared pipeline and the Redis client.
+live_portrait_engine = None
+r = None
+
+
+def parse_args():
+    """Parse the Redis connection options of the worker."""
+    parser = argparse.ArgumentParser(description="LivePortrait Redis Task Worker")
+    parser.add_argument("--redis-host", type=str, default="43.156.39.249", help="Redis host")
+    parser.add_argument("--redis-port", type=int, default=31317, help="Redis port")
+    return parser.parse_args()
+
+
+def partial_fields(target_class, kwargs):
+    """Build target_class from the entries of kwargs that name one of its attributes."""
+    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
+
+
+def default_argument_config():
+    """Return an ArgumentConfig holding only default values.
+
+    ``args=[]`` keeps tyro from parsing this process's own command line
+    (``--redis-host`` / ``--redis-port``), which it would reject.
+    """
+    return tyro.cli(ArgumentConfig, args=[])
+
+
+def init_live_portrait_pipeline():
+    """Create the LivePortrait pipeline from the default configuration."""
+    default_args = default_argument_config()
+
+    # specify configs for inference
+    inference_cfg = partial_fields(InferenceConfig, default_args.__dict__)
+    crop_cfg = partial_fields(CropConfig, default_args.__dict__)
+
+    return LivePortraitPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg)
+
+
+def task_worker(
+    stream_name: str,
+    consumer_group: str,
+    consumer_name: str,
+    task_prefix: str = "lp-task-",
+    lock_key_prefix: str = "lock_key_lp",
+    lock_timeout: int = 30,
+):
+    """Consume tasks from a Redis stream forever and run LivePortrait on each one.
+
+    Every message is acknowledged and deleted once it has been handled, whether
+    the task succeeded, failed, or was skipped because another consumer owns it.
+    """
+    logger.info(f"Starting task worker {consumer_name} in {consumer_group} for stream {stream_name}...")
+    conn_kwargs = r.connection_pool.connection_kwargs
+
+    while True:
+        logger.info(f"Listening for tasks in {stream_name}...")
+
+        tasks = r.xreadgroup(consumer_group, consumer_name, {stream_name: ">"}, count=1, block=0)
+
+        for _, task_list in tasks or []:
+            for task_id, task_data in task_list:
+                serving_task_id = task_data.get("request_id")
+                if not serving_task_id:
+                    # Malformed message: drop it instead of crashing on ``prefix + None``.
+                    logger.error(f"Message {task_id} carries no request_id. Dropping it.")
+                    r.xack(stream_name, consumer_group, task_id)
+                    r.xdel(stream_name, task_id)
+                    continue
+
+                lock_key = lock_key_prefix + serving_task_id
+                task_key = task_prefix + serving_task_id
+
+                try:
+                    # The deployed Redis (KeyDB) runs in cluster mode with single-service
+                    # compatibility and may deliver a message more than once, so take a lock.
+                    lock_acquired = r.set(lock_key, consumer_name, nx=True, ex=lock_timeout)
+                    if not lock_acquired:
+                        logger.info(f"Failed to acquire lock {lock_key}. Skipping...")
+                        continue
+                    logger.info(
+                        f"Acquired lock {lock_key} (expiring in {lock_timeout} secs), processing tasks from stream {stream_name} "
+                        f"on {conn_kwargs['host']}:{conn_kwargs['port']}"
+                    )
+
+                    # Read the payload before marking ownership: writing first would make the
+                    # hash non-empty and hide a task whose data is missing or has expired.
+                    task_payload = r.hgetall(task_key)
+                    logger.info(f"Processing task {serving_task_id}, task payload: {task_payload}")
+                    if not task_payload:
+                        raise RuntimeError(f"Task payload of {serving_task_id} not found")
+
+                    occupied_by_pod = task_payload.get("occupied_by_pod")
+                    if occupied_by_pod and occupied_by_pod != consumer_name:
+                        logger.error(
+                            f"Task {serving_task_id} is already being processed by another consumer: {occupied_by_pod}. Skipping..."
+                        )
+                        continue
+                    # NOTE Important: mark the task as occupied by the current consumer
+                    r.hset(task_key, mapping={"occupied_by_pod": consumer_name})
+                    logger.info(f"Set task {serving_task_id} as occupied by {consumer_name}")
+
+                    args = default_argument_config()
+                    args.source = task_payload.get("src_key")
+                    args.driving = task_payload.get("driving_key")
+                    args.output_dir = OUTPUT_LOCAL_PATH
+                    wfp, wfp_concat = live_portrait_engine.execute(args)
+                    logger.info(f"Task {serving_task_id} finished, outputs: {wfp}, {wfp_concat}")
+
+                except Exception:
+                    # TODO: callback with failed status
+                    logger.exception(f"Error processing task {serving_task_id}")
+                finally:
+                    # NOTE: messages are always acknowledged and deleted after the task finished or failed;
+                    # We won't add back the task to the stream if it's failed
+                    r.xack(stream_name, consumer_group, task_id)
+                    r.xdel(stream_name, task_id)
+
+                    if r.get(lock_key) == consumer_name:
+                        r.delete(lock_key)
+                        logger.info(f"Released lock {lock_key} by consumer {consumer_name}")
+
+        time.sleep(1.0)
+
+
+def run_worker(args):
+    """Connect to Redis, ensure the consumer group exists, load the pipeline, and start consuming."""
+    global live_portrait_engine, r
+
+    redis_host, redis_port = args.redis_host, args.redis_port
+    # task_group_lp
+    consumer_group = CONSUMER_GROUP_PREFIX + TASK_TAG
+    # task_stream_lp
+    stream_name = STREAM_NAME_PREFIX + TASK_TAG
+
+    r = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)
+    logger.info(f"Creating consumer group {consumer_group} for stream {stream_name}")
+    try:
+        r.xgroup_create(stream_name, consumer_group, id="0", mkstream=True)
+    except redis.exceptions.ResponseError as e:
+        if "BUSYGROUP" in str(e):
+            logger.warning(f"Consumer group {consumer_group} already exists")
+        else:
+            raise
+
+    logger.info("Initializing LivePortrait engine...")
+    live_portrait_engine = init_live_portrait_pipeline()
+    logger.info("LivePortrait engine initialized")
+
+    task_worker(
+        stream_name=stream_name,
+        consumer_group=consumer_group,
+        consumer_name=CONSUMER_NAME,
+        task_prefix=TASK_PREFIX,
+        lock_key_prefix=LOCK_KEY_PREFIX,
+    )
+
+
+if __name__ == "__main__":
+    # run_worker is synchronous and blocks forever; no event loop is needed.
+    run_worker(parse_args())
diff --git a/server/proxy_server.py b/server/proxy_server.py
new file mode 100644
index 0000000..8460813
--- /dev/null
+++ b/server/proxy_server.py
@@ -0,0 +1,125 @@
+import argparse
+import logging
+import os
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import aioredis
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel, Field, root_validator
+from starlette.responses import JSONResponse
+
+TASK_PREFIX = "lp-task-"  # Prefix for task keys, used for retrieving data/status
+TASK_TAG = "lp"
+STREAM_NAME_PREFIX = os.getenv("REDIS_STREAM_PREFIX", "task_stream") + "_"
+CONSUMER_GROUP_PREFIX = os.getenv("REDIS_GROUP_PREFIX", "task_group") + "_"
+TASK_STATUS_EXPIRE = int(os.getenv("TASK_STATUS_EXPIRE", 43200))  # 12 hours by default
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+
+# Redis client, created in lifespan() when the application starts.
+r = None
+
+
+async def _init_consumer_group(stream_name: str, group_name: str):
+    """Create the consumer group (and the stream, if missing); an existing group is not an error."""
+    try:
+        await r.xgroup_create(stream_name, group_name, id="0", mkstream=True)
+    except aioredis.exceptions.ResponseError as e:
+        if "BUSYGROUP" in str(e):
+            logger.warning(f"Consumer group {group_name} already exists in {stream_name}")
+        else:
+            raise
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Open the Redis connection on startup and close it on shutdown.
+
+    FastAPI expects ``lifespan`` to be an async context manager; a bare async
+    generator is only accepted through a deprecated compatibility path.
+    """
+    global r
+    r = aioredis.Redis(host=args.redis_host, port=args.redis_port, decode_responses=True)
+
+    await _init_consumer_group(STREAM_NAME_PREFIX + TASK_TAG, CONSUMER_GROUP_PREFIX + TASK_TAG)
+
+    try:
+        yield
+    finally:
+        await r.close()
+        logger.info("Redis connection closed")
+        logger.info("Proxy server shut down")
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+class LivePortraitRequestProto(BaseModel):
+    '''Define the request schema for the Live Portrait API.
+    Live Portrait needs two inputs to generate the final output:
+    1. The source image/video (src_key)
+    2. The driving image/video (driving_key)
+    This server will process source input according to the driving input.
+    Some characteristics of the driving input will be transferred and applied to the source input.
+    '''
+    request_id: str = Field('test_request_id_0', description="Unique request ID", min_length=1)
+    src_key: Optional[str] = Field(None, description="Object key of the source image/video from COS", min_length=1)
+    driving_key: Optional[str] = Field(None, description="Object key of the driving image/video from COS", min_length=1)
+
+    @root_validator(pre=True)
+    def check_either_image_key_or_local_path(cls, values):
+        src_key = values.get("src_key")
+        driving_key = values.get("driving_key")
+        if not src_key or not driving_key:
+            raise ValueError("The 'src_key' and 'driving_key' must be provided")
+        return values
+
+
+@app.post("/submit")
+async def submit(req: LivePortraitRequestProto):
+    """Store the task payload in Redis and enqueue the request ID on the task stream."""
+    mapping = {
+        "src_key": req.src_key,
+        "driving_key": req.driving_key,
+    }
+
+    # Task data/status must be set first before adding to the queue
+    task_name = TASK_PREFIX + req.request_id
+
+    logger.info(f"Receiving task {task_name} with payload {mapping}")
+
+    n = await r.hset(task_name, mapping=mapping)
+    await r.expire(task_name, TASK_STATUS_EXPIRE)
+    logger.info(f"{n} fields successfully added for {task_name} (expire in {TASK_STATUS_EXPIRE} secs)")
+
+    # Example: task_stream_lp
+    queue_name = STREAM_NAME_PREFIX + TASK_TAG
+    redis_task_id = await r.xadd(queue_name, {"request_id": req.request_id})
+    logger.info(f"Redis task {redis_task_id} added to Redis Stream {queue_name}")
+
+    # trim the Redis stream
+    await r.xtrim(queue_name, maxlen=2000, approximate=True)
+
+    return JSONResponse(content={"request_id": req.request_id, "message_id": redis_task_id})
+
+
+def parse_args():
+    """Parse the listen address and the Redis connection options."""
+    parser = argparse.ArgumentParser(description="Proxy server for LivePortrait serving")
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--redis-host", type=str, default="43.156.39.249", help="Redis host")
+    parser.add_argument("--redis-port", type=int, default=31317, help="Redis port")
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # uvicorn.run() blocks until shutdown, so announce the address before calling it.
+    logger.info(f"Proxy server is starting on {args.host}:{args.port}")
+    uvicorn.run(app, host=args.host, port=args.port)
diff --git a/server/requirements.txt b/server/requirements.txt
new file mode 100644
index 0000000..66033a8
--- /dev/null
+++ b/server/requirements.txt
@@ -0,0 +1,5 @@
+aioredis
+fastapi
+redis
+requests
+uvicorn
diff --git a/server/server.py b/server/server.py
new file mode 100644
index 0000000..43f2ee6
--- /dev/null
+++ b/server/server.py
@@ -0,0 +1,176 @@
+import argparse
+import logging
+import os
+import secrets
+import sys
+from typing import Optional, Tuple
+
+import tyro
+from fastapi import Depends, FastAPI
+from fastapi.security import APIKeyHeader
+from pydantic import BaseModel, Field, root_validator
+from starlette.responses import JSONResponse
+
+# Make the repository root importable so that ``src`` resolves when running from server/.
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, parent_dir)
+
+from src.config.argument_config import ArgumentConfig
+from src.config.inference_config import InferenceConfig
+from src.config.crop_config import CropConfig
+from src.live_portrait_pipeline import LivePortraitPipeline
+
+# Initialize logger
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# NOTE: For now, we use an API key to authenticate the requests.
+# This method should only be used for testing purposes.
+# For actual deployment, we should use OAuth2 or other more secure methods.
+# NOTE(review): the fallback key below is committed to the repository; always set
+# SERVING_API_KEY in any shared environment.
+header_scheme = APIKeyHeader(name="X-API-Key")
+SERVING_API_KEY = os.getenv("SERVING_API_KEY", "oLjQD5hWDYN5DeAQ4cx5CL3vJYOTXf0c")
+OUTPUT_LOCAL_PATH = os.getenv("OUTPUT_LOCAL_PATH", "/tmp/outputs/")
+
+app = FastAPI()
+
+live_portrait_pipeline = None
+
+
+class LivePortraitRequestProto(BaseModel):
+    '''Define the request schema for the Live Portrait API.
+    Live Portrait needs two inputs to generate the final output:
+    1. The source image/video (src_key or src_local_path)
+    2. The driving image/video (driving_key or driving_local_path)
+    This server will process source input according to the driving input.
+    Some characteristics of the driving input will be transferred and applied to the source input.
+    '''
+    src_key: Optional[str] = Field(None, description="Object key of the source image/video from COS", min_length=1)
+    src_local_path: Optional[str] = Field(
+        None, description="Local file path of the source image/video to be processed on the server", min_length=1
+    )
+
+    driving_key: Optional[str] = Field(None, description="Object key of the driving image/video from COS", min_length=1)
+    driving_local_path: Optional[str] = Field(
+        None, description="Local file path of the driving image/video to be processed on the server", min_length=1
+    )
+
+    @root_validator(pre=True)
+    def check_either_image_key_or_local_path(cls, values):
+        src_key = values.get("src_key")
+        src_local_path = values.get("src_local_path")
+        driving_key = values.get("driving_key")
+        driving_local_path = values.get("driving_local_path")
+        if (src_key and src_local_path) or (not src_key and not src_local_path):
+            raise ValueError("Either 'src_key' or 'src_local_path' must be provided, but not both.")
+        if (driving_key and driving_local_path) or (not driving_key and not driving_local_path):
+            raise ValueError("Either 'driving_key' or 'driving_local_path' must be provided, but not both.")
+        return values
+
+
+def partial_fields(target_class, kwargs):
+    """Build target_class from the entries of kwargs that name one of its attributes."""
+    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
+
+
+def default_argument_config():
+    """Return an ArgumentConfig holding only default values.
+
+    ``args=[]`` keeps tyro from parsing this server's own command line
+    (``--host`` / ``--port``), which it would reject.
+    """
+    return tyro.cli(ArgumentConfig, args=[])
+
+
+def init_live_portrait_pipeline():
+    """Create the global LivePortrait pipeline from the default configuration."""
+    global live_portrait_pipeline
+
+    default_args = default_argument_config()
+    # specify configs for inference
+    inference_cfg = partial_fields(InferenceConfig, default_args.__dict__)
+    crop_cfg = partial_fields(CropConfig, default_args.__dict__)
+
+    live_portrait_pipeline = LivePortraitPipeline(inference_cfg=inference_cfg, crop_cfg=crop_cfg)
+
+
+def _resolve_input(object_key: Optional[str], local_path: Optional[str], storage_dir: str) -> Optional[str]:
+    """Return the local path of one input, or None when it cannot be resolved.
+
+    An object key maps to ``storage_dir/<basename of the key>``; a local path is
+    used as given and must already exist on the server.
+    """
+    if object_key:
+        os.makedirs(storage_dir, exist_ok=True)  # Ensure the directory exists
+        storage_path = os.path.join(storage_dir, os.path.basename(object_key))
+        # TODO: download the object from COS into storage_path, e.g.
+        # download_success = await cos_download_file(object_key, storage_path)
+        download_success = True
+        if not download_success:
+            logger.error(f"Failed to download object: {object_key}")
+            return None
+        return storage_path
+
+    # This branch is just for local tests
+    logger.warning("Using local files is not intended for deployment. This is for testing purposes only.")
+    if not local_path or not os.path.exists(local_path):
+        logger.error(f"Local path does not exist: {local_path}")
+        return None
+    return local_path
+
+
+def extract_inputs_from_request(request: LivePortraitRequestProto) -> Tuple[Optional[str], Optional[str]]:
+    """Resolve the source and driving inputs of a request to local file paths.
+
+    Returns ``(src_input, driving_input)``, or ``(None, None)`` if either one
+    could not be resolved.
+    """
+    src_input = _resolve_input(request.src_key, request.src_local_path, "/tmp/sources")
+    driving_input = _resolve_input(request.driving_key, request.driving_local_path, "/tmp/driving")
+    if src_input is None or driving_input is None:
+        return (None, None)
+    return (src_input, driving_input)
+
+
+@app.post("/live-portrait")
+def live_portrait(request: LivePortraitRequestProto, api_key: str = Depends(header_scheme)):
+    """Run LivePortrait on the requested inputs and report the output paths."""
+    # Authenticate first; compare as bytes in constant time (str inputs must be ASCII-only).
+    if not secrets.compare_digest(api_key.encode(), SERVING_API_KEY.encode()):
+        return JSONResponse({"message": "Invalid API key", "output_path": ''}, status_code=401)
+    if live_portrait_pipeline is None:
+        return JSONResponse({"message": "Server not ready", "output_path": ''}, status_code=503)
+
+    src_input, driving_input = extract_inputs_from_request(request)
+    if src_input is None or driving_input is None:
+        return JSONResponse({"message": "Failed to process inputs", "output_path": ''}, status_code=500)
+
+    args = default_argument_config()
+    args.source = src_input
+    args.driving = driving_input
+    args.output_dir = OUTPUT_LOCAL_PATH
+    try:
+        wfp, wfp_concat = live_portrait_pipeline.execute(args)
+    except Exception:
+        logger.exception("LivePortrait inference failed")
+        return JSONResponse({"message": "Inference failed", "output_path": ''}, status_code=500)
+
+    logger.info(f"Outputs written to {wfp} and {wfp_concat}")
+    return JSONResponse({"message": "success", "output_path": wfp, "output_path_concat": wfp_concat})
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Live Portrait Server")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
+    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
+    args = parser.parse_args()
+
+    # init_cos_client()
+    # logger.info("COS client initialized")
+
+    init_live_portrait_pipeline()
+
+    import uvicorn
+
+    uvicorn.run(app, host=args.host, port=args.port)