import json
from typing import AsyncIterator, Literal, Callable, Protocol, TypeVar, Generic
from urllib.parse import quote

import httpx

from . import structs
7
8
#: Type variable standing for the document type a :class:`DocumentLoader` handles.
DOCT = TypeVar("DOCT")
10
11
[docs]
class DocumentLoader(Protocol[DOCT]):
    """
    Structural interface for objects that translate between raw JSON blobs
    and hydrated document instances.
    """

    def loadj(self, blob: dict) -> DOCT:
        """
        Turn a JSON blob into a document object.
        """

    def dumpj(self, doc: DOCT) -> dict:
        """
        Turn a document object back into a JSON blob.
        """
27
28
[docs]
class DocumentRegistry:
    """
    Handles de/serialization, manages migrations, etc.

    Default loader implementation.

    Do not use directly. You probably want one of the :ref:`integrations`.

    Every subclass gets its own document-class and migration tables, so
    separate registries never see each other's registrations.
    """

    #: Key in the JSON blob under which the registered type name is stored.
    TYPE_KEY = ""

    # Registered type name -> document class.
    _docclasses: dict[str, type] = {}
    # (before name, after name, migration function), in registration order.
    _migrations: list[tuple[str, str, Callable]] = []

    def __init_subclass__(cls, **kwargs):
        # Give each registry subclass fresh tables; inheriting the base
        # class's dict/list would make all registries share registrations.
        super().__init_subclass__(**kwargs)
        cls._docclasses = {}
        cls._migrations = []

    @classmethod
    def _get_class_from_name(cls, name: str) -> type:
        # Raises KeyError for a name that was never registered.
        return cls._docclasses[name]

    @classmethod
    def _get_name_from_class(cls, klass: type) -> str:
        """
        Find the registered name for ``klass``.

        The most specific registered class in ``klass``'s MRO wins, so a
        registered subclass is never reported under its parent's name. As a
        last resort :func:`issubclass` is used (in case of decorator
        shenanigans).

        Raises:
            ValueError: if no registered class matches.
        """
        by_class = {}
        for name, kind in cls._docclasses.items():
            # setdefault keeps the first name if a class is registered twice.
            by_class.setdefault(kind, name)
        for base in klass.__mro__:
            if base in by_class:
                return by_class[base]
        for name, kind in cls._docclasses.items():
            if issubclass(klass, kind):
                return name
        raise ValueError(f"Couldn't find name for {klass}")

    @classmethod
    def document(cls, name: str):
        """
        Register a class as a loadable couch document.

        Args:
            name: The type identifier to save to CouchDB

        Raises:
            TypeError: if used bare (``@Registry.document`` without a name).
            ValueError: if ``name`` is already registered on this registry.

        .. note::

            The given name must be globally unique and must never change.
        """
        if isinstance(name, type):
            raise TypeError(
                f"{cls.__name__}.document() needs a name, "
                f"e.g. @{cls.__name__}.document('name')"
            )
        if name in cls._docclasses:
            raise ValueError(f"Document name {name!r} is already registered")

        def _(klass: type):
            cls._docclasses[name] = klass
            return klass

        return _

    @classmethod
    def migration(cls, before: type, after: type):
        """
        Define a function that'll convert between documents.

        The decorated function is called with a ``before`` document when one
        is loaded, and is expected to return an ``after`` document.
        Migrations chain until no further migration applies.

        Raises:
            ValueError: if either class is unregistered, or a migration from
                ``before`` already exists (migrations must be linear).
        """
        # Normalize to the document classes previously registered
        bname = cls._get_name_from_class(before)
        aname = cls._get_name_from_class(after)
        # Enforce linearity
        if any(b == bname for b, _, _ in cls._migrations):
            raise ValueError(f"A migration from {bname!r} is already registered")

        def _(func: Callable):
            cls._migrations.append((bname, aname, func))
            return func

        return _

    def load_doc(self, cls: type, blob: dict):
        """
        Converts a JSON blob into a document.

        Override me.
        """
        raise NotImplementedError

    def dump_doc(self, doc) -> dict:
        """
        Convert a document into a JSON blob.

        Override me.
        """
        raise NotImplementedError

    def _migrate(self, bname, doc):
        """
        Apply registered migrations to ``doc`` (currently registered as
        ``bname``) until none applies, and return the result.

        Raises:
            ValueError: if the migrations form a cycle, which would otherwise
                loop forever.
        """
        seen = {bname}
        while funcs := [f for b, _, f in self._migrations if b == bname]:
            (func,) = funcs  # migration() rejects a second one per source
            doc = func(doc)
            bname = self._get_name_from_class(type(doc))
            if bname in seen:
                raise ValueError(f"Migration cycle detected at {bname!r}")
            seen.add(bname)
        return doc

    def loadj(self, blob):
        """
        Convert a JSON blob into a (fully migrated) document.

        The caller's ``blob`` is left unmodified.
        """
        blob = dict(blob)  # pop() below must not touch the caller's dict
        name = blob.pop(self.TYPE_KEY)
        klass = self._get_class_from_name(name)
        doc = self.load_doc(klass, blob)
        return self._migrate(name, doc)

    def dumpj(self, doc):
        """
        Convert a document into a JSON blob tagged with its registered name.
        """
        blob = self.dump_doc(doc)
        blob[self.TYPE_KEY] = self._get_name_from_class(type(doc))
        return blob
131
132
[docs]
class Conflict(Exception):
    """
    The operation could not be applied because it conflicts with the
    current state on the server (e.g. an HTTP 409 on update, or creating a
    database that already exists).
    """
137
138
[docs]
class Missing(Exception):
    """
    The requested resource could not be found.

    This is an HTTP 404, not a tombstone; a document that exists but is
    marked deleted raises :class:`Deleted` instead.
    """
145
146
[docs]
class Deleted(Exception):
    """
    The requested document exists but is marked as deleted.

    This is a tombstoned document, not an HTTP 404 (that case raises
    :class:`Missing`).
    """
153
154
[docs]
155class CouchSession:
156 """
157 A connection to CouchDB.
158
159 You probably want to override like::
160
161 class MySession(CouchSession):
162 loader = MyRegistry
163
164 """
165
166 _client: httpx.AsyncClient
167 _root: httpx.URL
168
169 #: Class responsible for de/serializing data.
170 loader: type[DocumentLoader]
171
172 def __init__(self, client: httpx.AsyncClient, root: httpx.URL):
173 self._client = client
174 self._root = root
175
176 @staticmethod
177 def _fix_params(params):
178 rv = {}
179 for key, value in params.items():
180 if value is None:
181 continue
182 elif isinstance(value, str):
183 rv[key] = value
184 else:
185 rv[key] = json.dumps(value)
186 return rv
187
188 async def _request(self, method, *urlparts, **kwargs):
189 url = self._root.join("/".join(urlparts))
190 if "params" in kwargs:
191 kwargs["params"] = self._fix_params(kwargs["params"])
192 resp = await self._client.request(method, url, **kwargs)
193 try:
194 resp.raise_for_status()
195 except httpx.HTTPStatusError as exc:
196 exc.add_note(f"Body: {exc.response.text}")
197 match exc.response.status_code:
198 case 404:
199 raise Missing(f"Could not find {'/'.join(urlparts)}") from exc
200 case 409:
201 raise Conflict(f"Conflict updating {'/'.join(urlparts)}") from exc
202 case _:
203 raise
204 return resp
205
206 def __getitem__(self, key: str) -> "Database":
207 """
208 Gets a database.
209
210 (Does not actually check if it exists.)
211 """
212 return Database(self, key)
213
[docs]
214 async def get_db(self, dbname: str) -> "Database":
215 """
216 Gets a database. Checks if it exists.
217
218 See :http:head:`/{db}`
219 """
220 await self._request("HEAD", dbname)
221 return Database(self, dbname)
222
[docs]
223 async def create_db(
224 self,
225 dbname: str,
226 *,
227 shards: int | None = None,
228 replicas: int | None = None,
229 partitioned: bool | None = None,
230 ) -> "Database":
231 """
232 Create a database
233
234 See :http:post:`/{db}`
235 """
236 try:
237 await self._request(
238 "PUT",
239 dbname,
240 params={
241 "q": shards,
242 "n": replicas,
243 "partitioned": partitioned,
244 },
245 )
246 except httpx.HTTPStatusError as exc:
247 match exc.response.status_code:
248 case 412:
249 raise Conflict("Database already exists") from exc
250 case _:
251 raise
252 return Database(self, dbname)
253
[docs]
254 async def delete_db(self, dbname: str):
255 """
256 Delete a database.
257
258 See :http:delete:`/{db}`
259 """
260 await self._request("DELETE", dbname)
261
[docs]
262 async def iter_dbs(self) -> AsyncIterator[str]:
263 """
264 List all databases
265
266 See :http:get:`/_all_dbs`
267 """
268 resp = await self._request("GET", "_all_dbs")
269 for dbname in resp.json():
270 yield dbname
271
272 # TODO: Database metadata
273
274
[docs]
class Database:
    """
    An individual database.
    """

    def __init__(self, session, name):
        """
        :private:
        """
        self._session = session
        self._name = name

    # Couch bookkeeping (database name, docid, ETag) is stored on each
    # document under double-underscore attribute names. Python mangles these
    # to ``_Database__db`` etc., keeping them out of the way of the
    # document's own attributes.

    @staticmethod
    def _set_meta(doc, db, docid, etag):
        # Attach (or refresh) the bookkeeping on ``doc``.
        doc.__db = db
        doc.__docid = docid
        doc.__etag = etag

    @staticmethod
    def _get_meta(doc):
        # Read the bookkeeping back; whatever is missing comes back as None.
        db = docid = etag = None
        try:
            db = doc.__db
            docid = doc.__docid
            etag = doc.__etag
        except AttributeError:
            pass
        return db, docid, etag

    @staticmethod
    def _etag_from_write(resp):
        # Build the new ETag from a write response's JSON "rev". Returns None
        # when the body is not JSON or reports no revision (presumably the
        # case for ``batch=ok`` writes -- confirm against CouchDB).
        try:
            info = resp.json()
        except ValueError:
            return None
        rev = info.get("rev") if isinstance(info, dict) else None
        return f'"{rev}"' if rev else None

    def _blob2doc(self, blob, db, docid=..., etag=...):
        """
        Hydrate ``blob`` into a document and attach its couch bookkeeping.

        ``docid`` and ``etag`` default to the blob's ``_id`` and ``_rev``.
        """
        if docid is ...:
            docid = blob["_id"]
        if etag is ...:
            etag = f'"{blob["_rev"]}"'
        doc = self._session.loader().loadj(blob)
        self._set_meta(doc, db, docid, etag)
        return doc

    def _doc2blob(self, doc):
        """
        Serialize ``doc`` and return ``(blob, db, docid, etag)``; bookkeeping
        values the document does not carry are None.
        """
        blob = self._session.loader().dumpj(doc)
        db, docid, etag = self._get_meta(doc)
        return blob, db, docid, etag

    async def get(
        self,
        docid: str,
        *,
        attachments: bool = False,
        conflicts: bool = False,
        deleted_conflicts: bool = False,
        latest: bool = False,
        local_seq: bool = False,
        meta: bool = False,
        open_revs: list[str] | Literal["all"] | None = None,
        rev: str | None = None,
        revs: bool = False,
        revs_info: bool = False,
    ):
        """
        Get a document

        Raises:
            Missing: if the document does not exist (HTTP 404).
            Deleted: if the document is marked as deleted.

        See :http:get:`/{db}/{docid}`
        """
        resp = await self._session._request(
            "GET",
            self._name,
            docid,
            params={
                "attachments": attachments,
                "conflicts": conflicts,
                "deleted_conflicts": deleted_conflicts,
                "latest": latest,
                "local_seq": local_seq,
                "meta": meta,
                "open_revs": open_revs,
                "rev": rev,
                "revs": revs,
                "revs_info": revs_info,
            },
            headers={
                "Accept": "application/json",
            },
        )

        # NOTE(review): with ``open_revs`` the body is presumably a JSON list,
        # which is not handled here yet -- confirm.
        blob = resp.json()
        if blob.get("_deleted", False):  # TODO: Flag to override this
            raise Deleted(f"Document {self._name}/{docid} is marked as deleted")
        if "ETag" in resp.headers:
            etag = resp.headers["ETag"]
        else:
            # Conflicts mode
            etag = f'"{blob["_rev"]}"'
        return self._blob2doc(blob, self._name, docid, etag)

    # TODO: Attachments

    async def attempt_put(
        self,
        doc,
        docid: str | None = None,
        *,
        batch: bool = False,
    ):
        """
        Update a document.

        docid only needs to be given if it's a new document. On success the
        document is bound to this database and its ETag is refreshed, so it
        can be put again without a spurious conflict.

        Raises:
            ValueError: if the document belongs to another database, or no
                docid is known.
            Conflict: if the stored revision changed in the meantime.

        See :http:put:`/{db}/{docid}`
        """
        blob, cur_db, cur_docid, etag = self._doc2blob(doc)
        if cur_db is not None and cur_db != self._name:
            raise ValueError(
                f"Document belongs to database {cur_db!r}, not {self._name!r}"
            )
        target = cur_docid or docid
        if not target:
            raise ValueError("A docid is required to put a new document")
        resp = await self._session._request(
            "PUT",
            self._name,
            target,
            params={"batch": "ok"} if batch else {},
            headers={"If-Match": etag} if etag else {},
            json=blob,
        )
        self._set_meta(doc, self._name, target, self._etag_from_write(resp))

    async def attempt_delete(self, doc, *, batch: bool = False):
        """
        Delete a document

        The document must have been loaded from (or put to) this database.

        Raises:
            ValueError: if the document is not bound to this database or has
                no known revision.
            Conflict: if the stored revision changed in the meantime.

        See :http:delete:`/{db}/{docid}`
        """
        db, docid, etag = self._get_meta(doc)
        if db != self._name:
            raise ValueError(
                f"Document belongs to database {db!r}, not {self._name!r}"
            )
        if not docid or not etag:
            raise ValueError("Document has no known docid/revision to delete")
        resp = await self._session._request(
            "DELETE",
            db,
            docid,
            params={"batch": "ok"} if batch else {},
            headers={"If-Match": etag},
        )
        self._set_meta(doc, db, docid, self._etag_from_write(resp))

    async def attempt_copy(self, src_doc, dst_doc, *, batch: bool = False):
        """
        Copy a document

        .. todo::

            Implement

        Raises:
            NotImplementedError: always, until this is implemented.

        See :http:copy:`/{db}/{docid}`
        """
        # FIXME: Figure out signature
        # Fail loudly instead of silently pretending the copy happened.
        raise NotImplementedError("Database.attempt_copy is not implemented yet")

    async def mutate(self, docid: str) -> AsyncIterator:
        """
        A document mutation loop::

            async for doc in db.mutate("spam"):
                doc.foo = "bar"

        Each pass yields the document; once the loop body has run it is
        saved. On :class:`Conflict` the document is re-fetched and the body
        runs again, until the save goes through. Leaving the loop early
        (``break``) skips the save.
        """
        doc = await self.get(docid)
        while True:
            yield doc
            try:
                await self.attempt_put(doc)
            except Conflict:
                doc = await self.get(docid)
            else:
                break

    async def iter_all_docs(
        self, include_docs: bool = False
    ) -> AsyncIterator[structs.AllDocs_DocRef]:
        """
        List all documents

        TODO: More params

        Args:
            include_docs: Pre-load documents

        See :http:get:`/{db}/_all_docs`
        """
        resp = await self._session._request(
            "GET",
            self._name,
            "_all_docs",
            params={
                "include_docs": include_docs,
            },
            headers={
                "Accept": "application/json",
            },
        )
        for ref in resp.json()["rows"]:
            # A row may carry "doc": null; treat that like no document.
            raw = ref.get("doc")
            doc = None if raw is None else self._blob2doc(raw, self._name, ref["id"])
            yield structs.AllDocs_DocRef(
                _db=self, docid=ref["id"], rev=ref["value"]["rev"], _doc=doc
            )

    # TODO: Mango searches
    # TODO: Database operations
472
473
[docs]
class SessionPool:
    """
    Responsible for giving out Couch connections.

    You probably want to override like::

        class MyPool(SessionPool):
            session_class = MySession
    """

    _client: httpx.AsyncClient

    #: Class to use for sessions
    session_class: type[CouchSession]

    def __init__(self):
        super().__init__()
        self._client = self.make_client()

    def make_client(self) -> httpx.AsyncClient:
        """
        Produce an httpx client.

        The default client enables HTTP/2 and follows redirects.
        NOTE: httpx needs its optional ``http2`` extra installed for
        ``http2=True``.
        """
        return httpx.AsyncClient(http2=True, follow_redirects=True)

    async def iter_servers(self) -> AsyncIterator[str]:
        """
        Produce the list of potential servers.

        Override this
        """
        raise NotImplementedError
        # Unreachable, but the ``yield`` makes this an async generator, the
        # same kind of object overrides produce.
        for _ in ():
            yield

    @staticmethod
    def _normalize_root(url) -> httpx.URL:
        # Ensure a trailing slash so URL.join() appends to the root path
        # instead of replacing its last segment (RFC 3986 resolution).
        url = httpx.URL(url)
        if not url.path.endswith("/"):
            url = url.copy_with(path=url.path + "/")
        return url

    async def _check_server(self, url: httpx.URL):
        """
        Return whether ``url`` answers ``_up`` with a success status.

        Request-level failures (refused connection, timeout, ...) count as
        "not up", so :meth:`session` moves on to the next server.
        """
        try:
            resp = await self._client.get(url.join("_up"))
        except httpx.RequestError:
            return False
        return resp.is_success

    async def session(self) -> CouchSession:
        """
        Get a session

        Tries the servers from :meth:`iter_servers` in order and returns a
        session for the first one that is up.

        Raises:
            ConnectionError: if none of the servers is up.
        """
        async for url in self.iter_servers():
            root = self._normalize_root(url)
            if await self._check_server(root):
                return self.session_class(self._client, root)
        raise ConnectionError("No CouchDB server is available")

    async def aclose(self):
        """
        Close the underlying httpx client.

        Sessions handed out by this pool share that client and must not be
        used afterwards.
        """
        await self._client.aclose()