sent-file-fast-serve/core/storage.py
2024-11-30 19:49:52 +08:00

299 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# @Time : 2023/8/11 20:06
# @Author : Lan
# @File : storage.py
# @Software: PyCharm
from typing import Optional
import aiohttp
import asyncio
from pathlib import Path
import datetime
import io
import re
import sys
import aioboto3
from botocore.config import Config
from fastapi import HTTPException, Response, UploadFile
from core.response import APIResponse
from core.settings import data_root, settings
from apps.base.models import FileCodes
from core.utils import get_file_url
from fastapi.responses import FileResponse
class FileStorageInterface:
    """Common interface of all storage backends; every concrete class is a singleton.

    Instantiating a class repeatedly returns the same object. ``__init__`` is
    still invoked by Python on every instantiation, so subclasses should keep
    their ``__init__`` safe to run more than once.
    """

    _instance: Optional['FileStorageInterface'] = None

    def __new__(cls, *args, **kwargs):
        """Return the single instance of ``cls``, creating it on first use."""
        # Only look at the class's *own* namespace. Reading ``cls._instance``
        # would fall back to a parent's instance through inheritance, so once
        # the base class had been instantiated every subclass would hand back
        # that foreign object instead of building its own.
        instance = cls.__dict__.get('_instance')
        if instance is None:
            # ``object.__new__`` rejects extra arguments when ``__new__`` is
            # overridden, so constructor arguments are left to ``__init__``.
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    async def save_file(self, file: UploadFile, save_path: str):
        """Store the uploaded ``file`` under ``save_path``.

        Raises:
            NotImplementedError: always; concrete backends must override.
        """
        raise NotImplementedError

    async def delete_file(self, file_code: FileCodes):
        """Delete the stored file that belongs to ``file_code``.

        Raises:
            NotImplementedError: always; concrete backends must override.
        """
        raise NotImplementedError

    async def get_file_url(self, file_code: FileCodes):
        """Return the URL under which the shared file can be fetched.

        If the backend cannot expose files directly, downloads can be relayed
        through this server; in that case the implementation can call
        ``get_file_url`` from ``utils.py`` to obtain the relay download URL.

        Raises:
            NotImplementedError: always; concrete backends must override.
        """
        raise NotImplementedError

    async def get_file_response(self, file_code: FileCodes):
        """Return an HTTP response that delivers the file content.

        Must be implemented when the backend cannot expose files directly;
        other backends may leave it unimplemented.

        Raises:
            NotImplementedError: always, unless a backend overrides it.
        """
        raise NotImplementedError
class SystemFileStorage(FileStorageInterface):
    """Storage backend that keeps files on the local filesystem below ``data_root``."""

    def __init__(self):
        # Size of each read/write step when copying an upload to disk.
        self.chunk_size = 256 * 1024
        self.root_path = data_root

    def _save(self, file, save_path):
        """Copy the binary file object ``file`` to ``save_path`` in chunks.

        Blocking; ``save_file`` runs it in a worker thread.
        """
        with open(save_path, 'wb') as f:
            while chunk := file.read(self.chunk_size):
                f.write(chunk)

    async def save_file(self, file: UploadFile, save_path: str):
        """Write the upload to ``root_path / save_path``, creating parent directories."""
        target = self.root_path / save_path
        # exist_ok avoids a FileExistsError when two concurrent uploads try to
        # create the same directory between an exists() check and mkdir().
        target.parent.mkdir(parents=True, exist_ok=True)
        await asyncio.to_thread(self._save, file.file, target)

    async def delete_file(self, file_code: FileCodes):
        """Remove the file belonging to ``file_code``; a missing file is not an error."""
        target = self.root_path / await file_code.get_file_path()
        # missing_ok replaces the racy exists()/unlink() pair.
        target.unlink(missing_ok=True)

    async def get_file_url(self, file_code: FileCodes):
        """Return the server-relayed download URL for ``file_code``."""
        return await get_file_url(file_code.code)

    async def get_file_response(self, file_code: FileCodes):
        """Return the file as a ``FileResponse``, or a 404 ``APIResponse`` if it is gone."""
        file_path = self.root_path / await file_code.get_file_path()
        if not file_path.exists():
            return APIResponse(code=404, detail='文件已过期删除')
        return FileResponse(file_path, filename=file_code.prefix + file_code.suffix)
class S3FileStorage(FileStorageInterface):
    """Storage backend for S3-compatible object stores, accessed through aioboto3."""

    def __init__(self):
        self.access_key_id = settings.s3_access_key_id
        self.secret_access_key = settings.s3_secret_access_key
        self.bucket_name = settings.s3_bucket_name
        self.s3_hostname = settings.s3_hostname
        self.region_name = settings.s3_region_name
        self.signature_version = settings.s3_signature_version
        # An explicitly configured s3_endpoint_url takes precedence; otherwise
        # the endpoint is derived from the hostname.
        self.endpoint_url = settings.s3_endpoint_url or f'https://{self.s3_hostname}'
        self.aws_session_token = settings.aws_session_token
        self.proxy = settings.s3_proxy
        # The session token is attached to the session so that upload, delete
        # and presigning all use the same credentials (it used to be supplied
        # for uploads only).
        self.session = aioboto3.Session(
            aws_access_key_id=self.access_key_id,
            aws_secret_access_key=self.secret_access_key,
            aws_session_token=self.aws_session_token,
        )

    def _client(self):
        """Return an async context manager yielding an S3 client for the configured endpoint."""
        return self.session.client(
            "s3",
            endpoint_url=self.endpoint_url,
            region_name=self.region_name,
            config=Config(signature_version=self.signature_version),
        )

    async def _presigned_url(self, file_code: FileCodes):
        """Return a presigned GET URL for the object of ``file_code``, valid for one hour."""
        key = await file_code.get_file_path()
        async with self._client() as s3:
            return await s3.generate_presigned_url(
                'get_object',
                Params={'Bucket': self.bucket_name, 'Key': key},
                ExpiresIn=3600,
            )

    async def save_file(self, file: UploadFile, save_path: str):
        """Upload the whole content of ``file`` to the bucket under the key ``save_path``."""
        async with self._client() as s3:
            params = {'Bucket': self.bucket_name, 'Key': save_path, 'Body': await file.read()}
            # Only send ContentType when the upload actually carries one.
            if file.content_type:
                params['ContentType'] = file.content_type
            await s3.put_object(**params)

    async def delete_file(self, file_code: FileCodes):
        """Delete the object belonging to ``file_code`` from the bucket."""
        async with self._client() as s3:
            await s3.delete_object(Bucket=self.bucket_name, Key=await file_code.get_file_path())

    async def get_file_response(self, file_code: FileCodes):
        """Download the object through a presigned URL and return it as an attachment.

        Raises:
            HTTPException: status 503 when presigning or the download fails,
                including a non-success HTTP status from the object store.
        """
        try:
            filename = file_code.prefix + file_code.suffix
            link = await self._presigned_url(file_code)
            async with aiohttp.ClientSession() as session:
                async with session.get(link) as resp:
                    # Without this check an error document returned by the
                    # store would be served to the user as the file itself.
                    resp.raise_for_status()
                    content = await resp.read()
            # UTF-8 bytes re-read as latin-1; presumably so non-ASCII names
            # survive the header encoding - confirm against the web framework.
            header_name = filename.encode("utf-8").decode("latin-1")
            return Response(
                content,
                media_type="application/octet-stream",
                headers={"Content-Disposition": f'attachment; filename="{header_name}"'},
            )
        except Exception as exc:
            raise HTTPException(status_code=503, detail='服务代理下载异常,请稍后再试') from exc

    async def get_file_url(self, file_code: FileCodes):
        """Return the location of the shared item.

        Entries whose prefix is '文本分享' return ``file_code.text`` directly.
        With ``s3_proxy`` enabled the server-relayed URL is returned; otherwise
        a presigned S3 URL valid for one hour.
        """
        if file_code.prefix == '文本分享':
            return file_code.text
        if self.proxy:
            return await get_file_url(file_code.code)
        return await self._presigned_url(file_code)
class OneDriveFileStorage(FileStorageInterface):
    """Storage backend that keeps files in a OneDrive folder via Microsoft Graph."""

    def __init__(self):
        # Optional dependencies are imported lazily so that the module loads
        # without them when another backend is used.
        try:
            import msal
            from office365.graph_client import GraphClient
            from office365.runtime.client_request_exception import ClientRequestException
        except ImportError as e:
            raise ImportError('请先安装`msal`和`Office365-REST-Python-Client`') from e
        self.msal = msal
        self.domain = settings.onedrive_domain
        self.client_id = settings.onedrive_client_id
        self.username = settings.onedrive_username
        self.password = settings.onedrive_password
        self.proxy = settings.onedrive_proxy
        # Kept on the instance because the class is only importable in here.
        self._ClientRequestException = ClientRequestException
        try:
            client = GraphClient(self.acquire_token_pwd)
            self.root_path = client.me.drive.root.get_by_path(settings.onedrive_root_path).get().execute_query()
        except ClientRequestException as e:
            if e.code == 'itemNotFound':
                # NOTE(review): create_folder is not followed by its own
                # execute_query(); presumably it is sent together with the
                # query below - confirm against the client library.
                client.me.drive.root.create_folder(settings.onedrive_root_path)
                self.root_path = client.me.drive.root.get_by_path(settings.onedrive_root_path).get().execute_query()
            else:
                raise
        except Exception as e:
            raise Exception('OneDrive验证失败请检查配置是否正确\n' + str(e)) from e

    def acquire_token_pwd(self):
        """Acquire a Microsoft Graph token with the configured username and password.

        Returns:
            The result of ``acquire_token_by_username_password`` unchanged.
        """
        authority_url = f'https://login.microsoftonline.com/{self.domain}'
        app = self.msal.PublicClientApplication(
            authority=authority_url,
            client_id=self.client_id,
        )
        return app.acquire_token_by_username_password(
            username=self.username,
            password=self.password,
            scopes=['https://graph.microsoft.com/.default'],
        )

    def _get_path_str(self, path):
        """Normalise ``path`` to '/' separators and cut its last component at the first dot.

        Raises:
            TypeError: if ``path`` is neither ``str`` nor ``Path``.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError('path must be str or Path')
        parts = str(path).replace('\\', '/').replace('//', '/').split('/')
        # The last component loses everything from its first '.' onwards;
        # _save and _get_file_url use the result as the containing folder.
        parts[-1] = parts[-1].split('.')[0]
        return '/'.join(parts)

    def _save(self, file, save_path):
        """Upload the content of ``file`` under its own filename into the folder for ``save_path`` (blocking)."""
        content = file.file.read()
        name = file.filename
        path = self._get_path_str(save_path)
        self.root_path.get_by_path(path).upload(name, content).execute_query()

    async def save_file(self, file: UploadFile, save_path: str):
        """Upload ``file`` in a worker thread."""
        await asyncio.to_thread(self._save, file, save_path)

    def _delete(self, save_path):
        """Delete the remote item for ``save_path``; an 'itemNotFound' error is ignored (blocking)."""
        path = self._get_path_str(save_path)
        try:
            self.root_path.get_by_path(path).delete_object().execute_query()
        except self._ClientRequestException as e:
            # Already gone: deleting is treated as successful.
            if e.code != 'itemNotFound':
                raise

    async def delete_file(self, file_code: FileCodes):
        """Delete the remote item belonging to ``file_code`` in a worker thread."""
        await asyncio.to_thread(self._delete, await file_code.get_file_path())

    def _convert_link_to_download_link(self, link):
        """Rewrite a SharePoint sharing link into a ``download.aspx`` link.

        Raises:
            ValueError: if ``link`` does not have the expected
                ``https://<tenant>.sharepoint.com/.../personal/<user>/<share>`` shape.
        """
        tenant = re.search(r'https:\/\/(.+)\.sharepoint\.com', link)
        user = re.search(r'personal\/(.+)\/', link)
        if tenant is None or user is None:
            raise ValueError(f'unexpected OneDrive share link: {link}')
        p1 = tenant.group(1)
        p2 = user.group(1)
        # p2 comes from the link itself and is matched literally, so it has
        # to be escaped before being embedded into a pattern.
        share = re.search(re.escape(p2) + r'\/(.+)', link)
        if share is None:
            raise ValueError(f'unexpected OneDrive share link: {link}')
        p3 = share.group(1)
        return f'https://{p1}.sharepoint.com/personal/{p2}/_layouts/52/download.aspx?share={p3}'

    def _get_file_url(self, save_path, name):
        """Create an anonymous view link (expiring in one hour) for the file and return its download form (blocking)."""
        path = self._get_path_str(save_path)
        remote_file = self.root_path.get_by_path(path + '/' + name)
        expires_at = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)
        expiration_datetime = expires_at.strftime("%Y-%m-%dT%H:%M:%SZ")
        permission = remote_file.create_link("view", "anonymous", expiration_datetime=expiration_datetime).execute_query()
        return self._convert_link_to_download_link(permission.link.webUrl)

    async def get_file_response(self, file_code: FileCodes):
        """Download the file through a temporary share link and return it as an attachment.

        Raises:
            HTTPException: status 503 when creating the link or the download
                fails, including a non-success HTTP status from the remote side.
        """
        try:
            filename = file_code.prefix + file_code.suffix
            link = await asyncio.to_thread(self._get_file_url, await file_code.get_file_path(), filename)
            async with aiohttp.ClientSession() as session:
                async with session.get(link) as resp:
                    # Without this check a remote error page would be served
                    # to the user as the file itself.
                    resp.raise_for_status()
                    content = await resp.read()
            # UTF-8 bytes re-read as latin-1; presumably so non-ASCII names
            # survive the header encoding - confirm against the web framework.
            header_name = filename.encode("utf-8").decode("latin-1")
            return Response(
                content,
                media_type="application/octet-stream",
                headers={"Content-Disposition": f'attachment; filename="{header_name}"'},
            )
        except Exception as exc:
            raise HTTPException(status_code=503, detail='服务代理下载异常,请稍后再试') from exc

    async def get_file_url(self, file_code: FileCodes):
        """Return the server-relayed URL when ``onedrive_proxy`` is set, else a direct download link."""
        if self.proxy:
            return await get_file_url(file_code.code)
        return await asyncio.to_thread(
            self._get_file_url,
            await file_code.get_file_path(),
            f'{file_code.prefix}{file_code.suffix}',
        )
class OpenDALFileStorage(FileStorageInterface):
    """Storage backend that delegates to an OpenDAL ``AsyncOperator``."""

    def __init__(self):
        # Optional dependency, imported lazily so the module loads without it.
        try:
            import opendal
        except ImportError as e:
            raise ImportError('请先安装 `opendal`, 例如: "pip install opendal"') from e
        self.service = settings.opendal_scheme
        # Settings named ``opendal_<scheme>_<option>`` become operator options.
        # Matching on the full prefix including the trailing '_' keeps keys of
        # other schemes that merely start with the same letters out, tolerates
        # a bare ``opendal_<scheme>`` key, and works for schemes containing '_'.
        prefix = f'opendal_{self.service}_'
        service_settings = {
            key[len(prefix):]: value
            for key, value in settings.items()
            if key.startswith(prefix) and len(key) > len(prefix)
        }
        self.operator = opendal.AsyncOperator(settings.opendal_scheme, **service_settings)

    async def save_file(self, file: UploadFile, save_path: str):
        """Write the whole content of ``file`` to ``save_path``."""
        # UploadFile.read() is awaited instead of the blocking file.file.read(),
        # which would stall the event loop for the duration of the read.
        await self.operator.write(save_path, await file.read())

    async def delete_file(self, file_code: FileCodes):
        """Delete the stored object belonging to ``file_code``."""
        await self.operator.delete(await file_code.get_file_path())

    async def get_file_url(self, file_code: FileCodes):
        """Return the server-relayed download URL for ``file_code``."""
        return await get_file_url(file_code.code)

    async def get_file_response(self, file_code: FileCodes):
        """Read the object and return it as an attachment.

        Raises:
            HTTPException: status 404 on any failure; the underlying error is
                printed to stderr.
        """
        try:
            filename = file_code.prefix + file_code.suffix
            content = await self.operator.read(await file_code.get_file_path())
            # Same encoding as the S3 and OneDrive backends use. A raw
            # non-ASCII filename cannot be encoded as a header value, and that
            # failure used to be reported as a misleading 404.
            header_name = filename.encode("utf-8").decode("latin-1")
            headers = {
                "Content-Disposition": f'attachment; filename="{header_name}"'
            }
            return Response(content, headers=headers, media_type="application/octet-stream")
        except Exception as e:
            print(e, file=sys.stderr)
            raise HTTPException(status_code=404, detail="文件已过期删除") from e
# Registry of the available storage backends: maps the backend name to the
# class implementing it. The values are classes, not instances.
storages: dict[str, type[FileStorageInterface]] = {
'local': SystemFileStorage,
's3': S3FileStorage,
'onedrive': OneDriveFileStorage,
'opendal': OpenDALFileStorage,
}