from chuangliangTool.db_base import Chuangliang_ad
from typing import Optional, Union, Dict


class GetBaiduToken:
    """Fetch Baidu API credentials (advertiser name + access token).

    Credentials are read from ``chuangliang_ad.media_account``. Only rows
    with ``media_type = 'baidu'``, ``is_delete = 0`` and
    ``product_version = 0`` are considered. The lookup key may be an
    advertiser name or a media account ID; ``_is_media_account_id`` decides
    which query is used.
    """

    # Shared SQL for both lookups; only the filtered column differs.
    # ``{column}`` is filled from the fixed column names below, never from
    # caller input. The caller-supplied identifier is always bound through
    # the ``%s`` parameter placeholder.
    _TOKEN_QUERY_TEMPLATE = """
        SELECT advertiser_name, access_token
        FROM chuangliang_ad.media_account
        FORCE INDEX(idx_media_type_advertiser_status_is_active)
        WHERE media_type = 'baidu'
          AND is_delete = 0
          AND {column} = %s
          AND product_version = 0
        LIMIT 1
    """

    # Query used when the identifier is an advertiser name.
    TOKEN_QUERY_BY_NAME = _TOKEN_QUERY_TEMPLATE.format(column="advertiser_name")
    # Query used when the identifier is a media account ID.
    TOKEN_QUERY_BY_ID = _TOKEN_QUERY_TEMPLATE.format(column="media_account_id")

    # Number of digits an identifier must have to be treated as a media
    # account ID (e.g. "12466757256").
    MEDIA_ACCOUNT_ID_LENGTH = 11

    def __init__(self, identifier: Union[str, int]):
        """
        :param identifier: an advertiser name (str) or a media account ID
            (int, or a str of digits).
        """
        self.identifier = identifier
        # Cached header; set by the first successful get_header() call so
        # later calls do not hit the database again.
        self.header: Optional[Dict[str, str]] = None

    def get_header(self) -> Optional[Dict[str, str]]:
        """Return a header dict holding the user name and access token.

        The header is cached on the instance after the first successful
        lookup. Failed or empty lookups are not cached, so a later call
        queries again.

        :return: ``{"userName": ..., "accessToken": ...}``, or ``None`` if no
            matching row was found or the lookup failed (a message is
            printed in both cases).
        """
        if self.header is not None:
            return self.header

        try:
            result = self._query_token()
            if not result:
                print(f"No active token found for: {self.identifier}")
                return None
            advertiser_name, access_token = result[0]
        except Exception as e:
            # Deliberately best-effort: report the problem and return None
            # instead of propagating database/driver errors to the caller.
            print(f"Error fetching token: {str(e)}")
            return None

        self.header = {
            "userName": advertiser_name,
            "accessToken": access_token,
        }
        return self.header

    def _query_token(self) -> list:
        """Run the parameterized query that matches the identifier's kind.

        :return: whatever ``Chuangliang_ad.query_params`` returns for the
            query. Presumably a list of row tuples; at most one row because
            of ``LIMIT 1``.
        """
        if self._is_media_account_id():
            sql = self.TOKEN_QUERY_BY_ID
            # Bind IDs as integers to match the numeric ID column.
            param = int(self.identifier)
        else:
            sql = self.TOKEN_QUERY_BY_NAME
            # Always bind names as strings. With common MySQL drivers an int
            # parameter is rendered unquoted. MySQL would then compare
            # advertiser_name numerically (e.g. '123abc' = 123 is true) and
            # could not use the index.
            param = str(self.identifier)
        return Chuangliang_ad.query_params(sql, (param,))

    def _is_media_account_id(self) -> bool:
        """Return True if the identifier looks like a media account ID.

        It must be a non-negative int, or a string of decimal digits, with
        exactly ``MEDIA_ACCOUNT_ID_LENGTH`` digits. Anything else is treated
        as an advertiser name.
        """
        ident = self.identifier
        # bool is a subclass of int but is never a valid account ID.
        if isinstance(ident, bool):
            return False
        if isinstance(ident, int):
            # A minus sign would otherwise pad str() to the expected length
            # (e.g. -1234567890 has 11 characters).
            if ident < 0:
                return False
            digits = str(ident)
        elif isinstance(ident, str):
            digits = ident
        else:
            return False
        # isdecimal() (unlike isdigit()) rejects characters such as '²' that
        # int() cannot parse, so _query_token's int() conversion is safe.
        return len(digits) == self.MEDIA_ACCOUNT_ID_LENGTH and digits.isdecimal()


# Usage examples
if __name__ == "__main__":
    # Look up by advertiser name
    # token_getter1 = GetBaiduToken("原生-SLG-乱世-安卓12A20KA00006")
    # header1 = token_getter1.get_header()
    # print(header1)

    # Look up by media account ID (string form)
    token_getter2 = GetBaiduToken("12466757256")
    header2 = token_getter2.get_header()
    print(header2)

    # Look up by media account ID (integer form)
    # token_getter3 = GetBaiduToken(12466757256)
    # header3 = token_getter3.get_header()
    # print(header3)