# 97 lines · 3.1 KiB · Python (listing metadata captured with this source)
import logging
from typing import Dict, Optional, Union

from chuangliangTool.db_base import Chuangliang_ad
class GetBaiduToken:
    """Fetch a Baidu media account's access token and expose it as header fields.

    The account is looked up either by media account ID or by advertiser
    name, depending on what ``identifier`` looks like (see
    ``_is_media_account_id``). After the first successful lookup the header
    dict is cached on the instance and returned without querying again.
    """

    # Logger used to report lookup failures and "not found" results.
    _logger = logging.getLogger(__name__)

    # Number of decimal digits an identifier must have to be treated as a
    # media account ID rather than an advertiser name.
    MEDIA_ACCOUNT_ID_LENGTH = 11

    # SQL templates for the two lookup modes. Both return at most one
    # non-deleted 'baidu' row with product_version = 0 and differ only in the
    # column compared against the %s placeholder. The value is passed
    # separately to query_params (presumably bound by the DB driver), never
    # formatted into the SQL text.
    TOKEN_QUERY_BY_NAME = """
        SELECT advertiser_name, access_token
        FROM chuangliang_ad.media_account
        FORCE INDEX(idx_media_type_advertiser_status_is_active)
        WHERE media_type = 'baidu'
          AND is_delete = 0
          AND advertiser_name = %s
          AND product_version = 0
        LIMIT 1
    """

    TOKEN_QUERY_BY_ID = """
        SELECT advertiser_name, access_token
        FROM chuangliang_ad.media_account
        FORCE INDEX(idx_media_type_advertiser_status_is_active)
        WHERE media_type = 'baidu'
          AND is_delete = 0
          AND media_account_id = %s
          AND product_version = 0
        LIMIT 1
    """

    def __init__(self, identifier: Union[str, int]):
        """
        Store the lookup key; no database access happens here.

        :param identifier: advertiser name (str), or media account ID
            (int, or str of decimal digits).
        """
        self.identifier = identifier
        # Cached header. It stays None until a lookup succeeds, so failed or
        # empty lookups are retried on the next get_header() call.
        self.header: Optional[Dict[str, str]] = None

    def get_header(self) -> Optional[Dict[str, str]]:
        """
        Return ``{"userName": ..., "accessToken": ...}`` for the account.

        The cached value is returned when present. Lookup is best-effort:
        when no row matches, or the query raises, the problem is logged and
        None is returned instead of propagating the error.

        :return: header dict, or None if no token could be fetched.
        """
        if self.header is not None:
            return self.header

        try:
            rows = self._query_token()
            if not rows:
                self._logger.warning("No active token found for: %s", self.identifier)
                return None
            # NOTE(review): assumes query_params returns a sequence of
            # (advertiser_name, access_token) tuples; a dict-style cursor
            # would unpack to the column names instead. Confirm.
            advertiser_name, access_token = rows[0]
        except Exception:
            # Deliberate best-effort boundary: log with traceback, report None.
            self._logger.exception("Error fetching token for: %s", self.identifier)
            return None

        self.header = {
            "userName": advertiser_name,
            "accessToken": access_token,
        }
        return self.header

    def _query_token(self) -> list:
        """
        Run the parameterized query matching the identifier's kind.

        :return: the rows returned by ``Chuangliang_ad.query_params``.
        """
        if self._is_media_account_id():
            sql = self.TOKEN_QUERY_BY_ID
            # Safe: _is_media_account_id only accepts non-negative ints or
            # strings of decimal digits, all of which int() can parse.
            param = int(self.identifier)
        else:
            sql = self.TOKEN_QUERY_BY_NAME
            param = self.identifier

        return Chuangliang_ad.query_params(sql, (param,))

    def _is_media_account_id(self) -> bool:
        """
        Tell whether the identifier looks like a media account ID.

        A media account ID is a non-negative number with exactly
        MEDIA_ACCOUNT_ID_LENGTH decimal digits, given either as an int or as
        a str of decimal digits. Anything else is treated as an advertiser
        name.
        """
        identifier = self.identifier
        expected = self.MEDIA_ACCOUNT_ID_LENGTH

        # bool is a subclass of int, but it is never a meaningful account ID.
        if isinstance(identifier, bool):
            return False

        if isinstance(identifier, int):
            # Reject negatives explicitly. Otherwise "-" plus 10 digits would
            # pass the length check.
            return identifier >= 0 and len(str(identifier)) == expected

        if isinstance(identifier, str):
            # isdecimal() (unlike isdigit()) rejects characters such as '²'
            # that int() cannot parse, so _query_token's int() cannot fail.
            return identifier.isdecimal() and len(identifier) == expected

        return False


# Usage example
if __name__ == "__main__":
    # Look up by media account ID passed as a string. Other supported forms:
    #   GetBaiduToken("原生-SLG-乱世-安卓12A20KA00006")  -> look up by advertiser name
    #   GetBaiduToken(12466757256)                       -> look up by ID passed as an int
    demo_getter = GetBaiduToken("12466757256")
    print(demo_getter.get_header())