Coverage for python/lsst/dax/apdb/cassandra/sessionFactory.py: 24%

98 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-05-28 01:36 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SessionContext", "SessionFactory"] 

25 

26import logging 

27from collections.abc import Mapping 

28from contextlib import ExitStack 

29from typing import TYPE_CHECKING, Any 

30 

31# If cassandra-driver is not there the module can still be imported 

32# but ApdbCassandra cannot be instantiated. 

33try: 

34 import cassandra 

35 import cassandra.query 

36 from cassandra.auth import AuthProvider, PlainTextAuthProvider 

37 from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile, Session 

38 from cassandra.policies import AddressTranslator, RoundRobinPolicy, WhiteListRoundRobinPolicy 

39 

40 CASSANDRA_IMPORTED = True 

41except ImportError: 

42 CASSANDRA_IMPORTED = False 

43 

44from lsst.utils.db_auth import DbAuth, DbAuthNotFoundError 

45 

46from ..monitor import MonAgent 

47from ..timer import Timer 

48from .cassandra_utils import raw_data_factory 

49 

50if TYPE_CHECKING: 

51 from .config import ApdbCassandraConfig 

52 

53_LOG = logging.getLogger(__name__) 

54 

55_MON = MonAgent(__name__) 

56 

57 

def _dump_query(rf: Any) -> None:
    """Write the query attached to a request to the debug log.

    Parameters
    ----------
    rf : `object`
        Object exposing a ``query`` attribute, which is logged as-is.
    """
    query = rf.query
    _LOG.debug("Cassandra query: %s", query)

61 

62 

if CASSANDRA_IMPORTED:

    class _AddressTranslator(AddressTranslator):
        """Map private (internal) IP addresses to their public counterparts.

        Only used for docker-based setup, not a viable long-term solution.

        Parameters
        ----------
        public_ips : `tuple` [`str`, ...]
            Public addresses.
        private_ips : `tuple` [`str`, ...]
            Private addresses, paired by position with ``public_ips``.
        """

        def __init__(self, public_ips: tuple[str, ...], private_ips: tuple[str, ...]):
            # Pairing is positional; ``zip`` stops at the shorter sequence.
            pairs = zip(private_ips, public_ips)
            self._map = dict(pairs)

        def translate(self, private_ip: str) -> str:
            """Return the public address for ``private_ip``, or the address
            itself when no mapping is known.
            """
            try:
                return self._map[private_ip]
            except KeyError:
                return private_ip

76 

77 

class SessionFactory:
    """Factory for Cassandra sessions that uses parameters from Apdb
    configuration.

    The cluster and session are created lazily on the first call to
    `session` and are reused by later calls.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    def __init__(self, config: ApdbCassandraConfig):
        self._config = config
        self._cluster: Cluster | None = None
        self._session: Session | None = None

    def __del__(self) -> None:
        # Need to call Cluster.shutdown() to avoid warnings. The attribute
        # may be missing if the instance was never fully initialized.
        cluster = getattr(self, "_cluster", None)
        if cluster:
            cluster.shutdown()

    def session(self) -> Session:
        """Return Cassandra Session, making new connection if necessary.

        Returns
        -------
        session : `cassandra.cluster.Session`
            Cassandra session object.
        """
        if self._session is None:
            self._cluster, self._session = self._make_session()
        return self._session

    def _make_session(self) -> tuple[Cluster, Session]:
        """Make Cassandra session.

        Returns
        -------
        cluster : `cassandra.cluster.Cluster`
            Cassandra Cluster object
        session : `cassandra.cluster.Session`
            Cassandra session object

        Notes
        -----
        If connecting to the cluster fails, the newly created cluster is shut
        down before the exception is re-raised, so that a failed attempt does
        not leave an open cluster behind.
        """
        connection_config = self._config.connection_config

        addressTranslator: AddressTranslator | None = None
        if connection_config.private_ips:
            addressTranslator = _AddressTranslator(
                self._config.contact_points, connection_config.private_ips
            )

        # Defaults, any of which can be overridden from configuration.
        extra_parameters: dict[str, Any] = {
            "idle_heartbeat_interval": 0,
            "idle_heartbeat_timeout": 30,
            "control_connection_timeout": 100,
            "executor_threads": 10,
        }
        extra_parameters.update(connection_config.extra_parameters)
        with Timer("cluster_connect", _MON):
            cluster = Cluster(
                execution_profiles=self._make_profiles(),
                contact_points=self._config.contact_points,
                port=connection_config.port,
                address_translator=addressTranslator,
                protocol_version=connection_config.protocol_version,
                auth_provider=self._make_auth_provider(),
                **extra_parameters,
            )
            try:
                session = cluster.connect()
            except Exception:
                # Caller never sees this cluster, so release it here.
                cluster.shutdown()
                raise

        # Dump queries if debug level is enabled.
        if _LOG.isEnabledFor(logging.DEBUG):
            session.add_request_init_listener(_dump_query)

        # Disable result paging
        session.default_fetch_size = None

        return cluster, session

    def _make_auth_provider(self) -> AuthProvider | None:
        """Make Cassandra authentication provider instance.

        Returns
        -------
        auth_provider : `cassandra.auth.AuthProvider` or `None`
            Provider built from the first credentials entry that has a
            non-empty user name. `None` (anonymous login) is returned if the
            credentials file does not exist or no usable entry is found.
        """
        try:
            dbauth = DbAuth()
        except DbAuthNotFoundError:
            # Credentials file doesn't exist, use anonymous login.
            return None

        connection_config = self._config.connection_config

        # If dbauth_alias is defined then try it first without port number.
        hosts: list[tuple[str, int | None]] = [
            (hostname, connection_config.port) for hostname in self._config.contact_points
        ]
        if dbauth_alias := self._config.get_dbauth_alias():
            hosts.insert(0, (dbauth_alias, None))

        empty_username = False
        # Try every contact point in turn.
        for hostname, port in hosts:
            try:
                username, password = dbauth.getAuth(
                    "cassandra",
                    connection_config.username,
                    hostname,
                    port,
                    self._config.keyspace,
                )
            except DbAuthNotFoundError:
                # No entry for this host, try the next one.
                continue
            if username:
                return PlainTextAuthProvider(username=username, password=password)
            # Password without user name, try next hostname, but give
            # warning later if no better match is found.
            empty_username = True

        if empty_username:
            _LOG.warning(
                f"Credentials file ({dbauth.db_auth_path}) provided password but not "
                "user name, anonymous Cassandra logon will be attempted."
            )

        return None

    def _make_profiles(self) -> Mapping[Any, ExecutionProfile]:
        """Make all execution profiles used in the code.

        Returns
        -------
        profiles : `~collections.abc.Mapping`
            Execution profiles indexed by profile name, plus the replacement
            for the default profile under ``EXEC_PROFILE_DEFAULT``.
        """
        connection_config = self._config.connection_config
        if connection_config.private_ips:
            loadBalancePolicy = WhiteListRoundRobinPolicy(hosts=self._config.contact_points)
        else:
            loadBalancePolicy = RoundRobinPolicy()

        read_consistency = getattr(cassandra.ConsistencyLevel, connection_config.read_consistency)
        write_consistency = getattr(cassandra.ConsistencyLevel, connection_config.write_consistency)
        read_timeout = connection_config.read_timeout
        # Very long timeout, to be used for querying DiaObjectDedup table
        # that can return a lot of data.
        dedup_read_timeout = 3600.0

        def read_profile(row_factory: Any, request_timeout: Any = read_timeout) -> ExecutionProfile:
            # All read profiles share consistency level and load balancing,
            # they differ only in row factory and timeout.
            return ExecutionProfile(
                consistency_level=read_consistency,
                request_timeout=request_timeout,
                row_factory=row_factory,
                load_balancing_policy=loadBalancePolicy,
            )

        write_profile = ExecutionProfile(
            consistency_level=write_consistency,
            request_timeout=connection_config.write_timeout,
            load_balancing_policy=loadBalancePolicy,
        )
        # To replace default DCAwareRoundRobinPolicy
        default_profile = ExecutionProfile(
            load_balancing_policy=loadBalancePolicy,
        )
        return {
            "read_tuples": read_profile(cassandra.query.tuple_factory),
            "read_named_tuples": read_profile(cassandra.query.named_tuple_factory),
            "read_raw": read_profile(raw_data_factory),
            # Profile to use with select_concurrent to return raw data
            # (columns and rows).
            "read_raw_multi": read_profile(raw_data_factory),
            # Same as above but with the long timeout.
            "read_raw_multi_dedup": read_profile(raw_data_factory, dedup_read_timeout),
            "write": write_profile,
            EXEC_PROFILE_DEFAULT: default_profile,
        }

259 

260 

class SessionContext(ExitStack):
    """Context manager which makes a new Cassandra session on entry and
    releases it, together with its cluster, on exit.

    Parameters
    ----------
    config : `ApdbCassandraConfig`
        Configuration object.
    """

    def __init__(self, config: ApdbCassandraConfig):
        super().__init__()
        self._session_factory = SessionFactory(config)

    def __enter__(self) -> Session:
        super().__enter__()
        new_cluster, new_session = self._session_factory._make_session()
        # Register cluster first so that it is exited after the session.
        for resource in (new_cluster, new_session):
            self.enter_context(resource)
        return new_session

279 return session