1import json
2import typing
3from typing import (
4 AsyncIterator,
5 Literal,
6 Callable,
7 Protocol,
8 TypeVar,
9 Generic,
10 ClassVar,
11 Any,
12)
13import warnings
14
15import httpx
16
17from . import structs
18
19
20DOCT = TypeVar("DOCT")
21
22# 3.12: type TypeIDType = ...
23# Technically, anything JSONable is allowable, but only allowing atomic types
24# makes a bunch of reasoning easier.
25#: The type of class identifiers
26TypeIDType = str | int | bool | None
27
28
[docs]
class DocumentLoader(Protocol, Generic[DOCT]):
    """
    Structural interface for objects that translate between JSON blobs and
    hydrated document instances of type ``DOCT``.
    """

    def load_from_blob(self, blob: dict) -> DOCT:
        """
        Build a document object out of a JSON blob.
        """
        ...

    def dump_to_blob(self, doc: DOCT) -> dict:
        """
        Produce the JSON blob representing a document.
        """
        ...

    def update_doc(self, doc: DOCT, **fields):
        """
        Apply the given fields to a document object, modifying it in-place.
        """
        ...
49
50
[docs]
class DocumentRegistry:
    """
    Handles de/serialization, manages migrations, etc.

    Default loader implementation.

    Do not use directly. You probably want one of the :ref:`integrations`.

    Every subclass receives its own, initially empty, registry of document
    classes and migrations (see :meth:`__init_subclass__`).
    """

    # Key inside a JSON blob under which the document's type identifier is
    # stored (read in load_from_blob, written in dump_to_blob).
    TYPE_KEY = ""

    # Registered document classes, keyed by their type identifier.
    _docclasses: ClassVar[dict[TypeIDType, type]] = {}
    # Registered migrations as (before-name, after-name, function) triples.
    _migrations: ClassVar[list[tuple[TypeIDType, TypeIDType, Callable]]] = []

    def __init_subclass__(cls, **kwargs):
        """
        Give each subclass independent registries.

        Without this hook, subclasses would all mutate the dict and list
        defined on :class:`DocumentRegistry` itself.
        """
        super().__init_subclass__(**kwargs)
        cls._docclasses = {}
        cls._migrations = []

    @classmethod
    def _get_class_from_name(cls, name: TypeIDType) -> type:
        """
        Look up the document class registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        return cls._docclasses[name]

    @classmethod
    def _get_name_from_class(cls, klass: type) -> TypeIDType:
        """
        Find the type identifier that ``klass`` was registered under.

        Raises:
            ValueError: if ``klass`` is not (a subclass of) any registered
                class.
        """
        for name, kind in cls._docclasses.items():
            # issubclass rather than identity, in case of decorator
            # shenanigans that wrap or subclass the registered class.
            if issubclass(klass, kind):
                return name
        raise ValueError(f"Couldn't find name for {klass}")

    @classmethod
    def document(cls, name: TypeIDType):
        """
        Register a class as a loadable couch document.

        Args:
            name: The type identifier to save to CouchDB

        Returns:
            A class decorator that registers the class and returns it
            unchanged.

        .. note::

           The given name must be globally unique and must never change.
        """
        # Catches ``@Registry.document`` being applied without a name.
        assert not isinstance(name, type)
        assert name not in cls._docclasses

        def _(klass: type):
            cls._docclasses[name] = klass
            return klass

        return _

    @classmethod
    def migration(cls, before: type, after: type):
        """
        Define a function that'll convert between documents.

        Both classes must already be registered with :meth:`document`.

        Returns:
            A decorator that records the function as the migration from
            ``before`` to ``after`` and returns it unchanged.

        Raises:
            ValueError: if either class has not been registered.
        """
        # Normalize to the document classes previously registered
        bname = cls._get_name_from_class(before)
        aname = cls._get_name_from_class(after)
        # Enforce linearity: at most one migration may start from a given
        # document type.
        assert not any(b == bname for b, _, _ in cls._migrations)

        def _(func: Callable):
            cls._migrations.append((bname, aname, func))
            return func

        return _

    def load_doc(self, cls: type, blob: dict):
        """
        Converts a JSON blob into a document.

        Override me.
        """
        raise NotImplementedError

    def dump_doc(self, doc) -> dict:
        """
        Convert a document into a JSON blob.

        Override me.
        """
        raise NotImplementedError

    def update_doc(self, doc, **fields):
        """
        Update a doc in-place
        """
        raise NotImplementedError

    def _migrate(self, bname, doc):
        """
        Repeatedly apply registered migrations to ``doc``.

        Starting from the type identifier ``bname``, the migration registered
        for the current type is applied and the type identifier is
        re-derived from the result, until no migration is registered for the
        current type.
        """
        while funcs := [f for b, _, f in self._migrations if b == bname]:
            # migration() guarantees at most one function per starting type.
            (func,) = funcs
            doc = func(doc)
            bname = self._get_name_from_class(type(doc))
        return doc

    def load_from_blob(self, blob):
        """
        Turn a JSON blob into a (fully migrated) document.

        The :attr:`TYPE_KEY` entry is removed from ``blob`` in place before
        the remainder is handed to :meth:`load_doc`.

        Raises:
            KeyError: if the blob has no :attr:`TYPE_KEY` entry or the type
                identifier is not registered.
        """
        type_id = blob.pop(self.TYPE_KEY)
        klass = self._get_class_from_name(type_id)
        doc = self.load_doc(klass, blob)
        return self._migrate(type_id, doc)

    def dump_to_blob(self, doc):
        """
        Turn a document into a JSON blob, tagging it with its type
        identifier under :attr:`TYPE_KEY`.
        """
        blob = self.dump_doc(doc)
        blob[self.TYPE_KEY] = self._get_name_from_class(type(doc))
        return blob
159
160
[docs]
class Conflict(Exception):
    """
    The attempted operation could not be performed because of a conflict.
    """
165
166
[docs]
class Missing(Exception):
    """
    The requested document could not be found.

    This corresponds to a 404, not to a tombstone.
    """
173
174
[docs]
class Deleted(Exception):
    """
    A deleted document was requested.

    This corresponds to a document with a tombstone, not to a 404.
    """
181
182
[docs]
class TooManyResults(Exception):
    """
    A find operation asked for a single document but matched more than one.

    This is based purely on the number of results; the request itself was
    successful.
    """
189
190
[docs]
class FindWarning(UserWarning):
    """
    A warning reported by the CouchDB _find endpoint.
    """
195
196
[docs]
197class CouchSession:
198 """
199 A connection to CouchDB.
200
201 You probably want to override like::
202
203 class MySession(CouchSession):
204 loader = MyRegistry
205
206 """
207
208 _client: httpx.AsyncClient
209 _root: httpx.URL
210
211 #: Class responsible for de/serializing data.
212 loader: type[DocumentLoader]
213
214 def __init__(self, client: httpx.AsyncClient, root: httpx.URL):
215 self._client = client
216 self._root = root
217
218 @staticmethod
219 def _fix_params(params):
220 rv = {}
221 for key, value in params.items():
222 if value is None:
223 continue
224 elif isinstance(value, str):
225 rv[key] = value
226 else:
227 rv[key] = json.dumps(value)
228 return rv
229
230 async def _request(self, method, *urlparts, **kwargs):
231 url = self._root.join("/".join(urlparts))
232 if "params" in kwargs:
233 kwargs["params"] = self._fix_params(kwargs["params"])
234 resp = await self._client.request(method, url, **kwargs)
235 try:
236 resp.raise_for_status()
237 except httpx.HTTPStatusError as exc:
238 exc.add_note(f"Body: {exc.response.text}")
239 match exc.response.status_code:
240 case 404:
241 raise Missing(f"Could not find {'/'.join(urlparts)}") from exc
242 case 409:
243 raise Conflict(f"Conflict updating {'/'.join(urlparts)}") from exc
244 case _:
245 raise
246 return resp
247
248 def __getitem__(self, key: str) -> "Database":
249 """
250 Gets a database.
251
252 (Does not actually check if it exists.)
253 """
254 return Database(self, key)
255
[docs]
256 async def get_db(self, dbname: str) -> "Database":
257 """
258 Gets a database. Checks if it exists.
259
260 See :http:head:`/{db}`
261 """
262 await self._request("HEAD", dbname)
263 return Database(self, dbname)
264
[docs]
265 async def create_db(
266 self,
267 dbname: str,
268 *,
269 shards: int | None = None,
270 replicas: int | None = None,
271 partitioned: bool | None = None,
272 ) -> "Database":
273 """
274 Create a database
275
276 See :http:post:`/{db}`
277 """
278 try:
279 await self._request(
280 "PUT",
281 dbname,
282 params={
283 "q": shards,
284 "n": replicas,
285 "partitioned": partitioned,
286 },
287 )
288 except httpx.HTTPStatusError as exc:
289 match exc.response.status_code:
290 case 412:
291 raise Conflict("Database already exists") from exc
292 case _:
293 raise
294 return Database(self, dbname)
295
[docs]
296 async def delete_db(self, dbname: str):
297 """
298 Delete a database.
299
300 See :http:delete:`/{db}`
301 """
302 await self._request("DELETE", dbname)
303
[docs]
304 async def iter_dbs(self) -> AsyncIterator[str]:
305 """
306 List all databases
307
308 See :http:get:`/_all_dbs`
309 """
310 resp = await self._request("GET", "_all_dbs")
311 for dbname in resp.json():
312 yield dbname
313
314 # TODO: Database metadata
315
316
[docs]
class Database:
    """
    An individual database.
    """

    def __init__(self, session, name):
        """
        :private:
        """
        self._session = session
        self._name = name

    # NOTE: ``doc.__db`` / ``doc.__docid`` / ``doc.__etag`` are written
    # inside this class body, so Python name-mangles them to
    # ``doc._Database__db`` etc. All reads and writes must stay inside
    # ``Database`` for them to refer to the same attributes.

    def _blob2doc(self, blob, db, docid, etag=...):
        """
        Hydrate a JSON blob into a document and stamp it with bookkeeping.

        Passing ``...`` for ``docid`` or ``etag`` means "take it from the
        blob" (``_id``, respectively ``_rev`` wrapped in double quotes).
        """
        if docid is ...:
            docid = blob["_id"]
        if etag is ...:
            etag = f'"{blob["_rev"]}"'
        doc = self._session.loader().load_from_blob(blob)
        doc.__db = db
        doc.__docid = docid
        doc.__etag = etag
        return doc

    def _doc2blob(self, doc):
        """
        Serialize a document and recover its bookkeeping.

        Returns:
            ``(blob, db, docid, etag)``; bookkeeping values the document
            does not carry are ``None``.
        """
        blob = self._session.loader().dump_to_blob(doc)
        db = docid = etag = None
        try:
            db = doc.__db
            docid = doc.__docid
            etag = doc.__etag
        except AttributeError:
            # Document never went through _blob2doc/_touch_doc.
            pass
        return blob, db, docid, etag

    def _touch_doc(self, doc, *, db=None, etag=None, rev=None, id=None):
        """
        Update a document's bookkeeping after a write.

        Only arguments that are not ``None`` are applied. ``rev`` and ``id``
        are additionally forwarded to the loader's ``update_doc`` as the
        ``_rev`` and ``_id`` fields.
        """
        fields = {}
        if etag is not None:
            doc.__etag = etag
        if db is not None:
            doc.__db = db
        if rev is not None:
            fields["_rev"] = rev
        if id is not None:
            doc.__docid = id
            fields["_id"] = id
        if fields:
            self._session.loader().update_doc(doc, **fields)

    async def get(
        self,
        docid: str,
        *,
        attachments: bool = False,
        conflicts: bool = False,
        deleted_conflicts: bool = False,
        latest: bool = False,
        local_seq: bool = False,
        meta: bool = False,
        open_revs: list[str] | Literal["all"] | None = None,
        rev: str | None = None,
        revs: bool = False,
        revs_info: bool = False,
    ):
        """
        Get a document

        See :http:get:`/{db}/{docid}`

        The keyword arguments are sent as the query parameters of the same
        name.

        Raises:
            Missing: if the server answers 404.
            Deleted: if the returned document has a truthy ``_deleted``.
        """
        resp = await self._session._request(
            "GET",
            self._name,
            docid,
            params={
                "attachments": attachments,
                "conflicts": conflicts,
                "deleted_conflicts": deleted_conflicts,
                "latest": latest,
                "local_seq": local_seq,
                "meta": meta,
                "open_revs": open_revs,
                "rev": rev,
                "revs": revs,
                "revs_info": revs_info,
            },
            headers={
                "Accept": "application/json",
            },
        )

        blob = resp.json()
        if blob.get("_deleted", False):  # TODO: Flag to override this
            raise Deleted(f"Document {self._name}/{docid} is marked as deleted")
        if "ETag" in resp.headers:
            etag = resp.headers["ETag"]
        else:
            # Conflicts mode
            etag = f'"{blob["_rev"]}"'

        # For some reason, requested fields are omitted when empty
        if attachments:
            blob.setdefault("_attachments", {})
        if conflicts:
            blob.setdefault("_conflicts", [])
        if deleted_conflicts:
            blob.setdefault("_deleted_conflicts", [])
        if revs:
            blob.setdefault("_revisions", {})
        if revs_info or open_revs:
            # Revision info is delivered under ``_revs_info`` (a list), not
            # under ``_revisions``.
            blob.setdefault("_revs_info", [])

        return self._blob2doc(blob, self._name, docid, etag)

    # TODO: Attachments

    async def find_one(
        self, selector: typing.Mapping, use_index: str | list[str] | None = None
    ):
        """
        Get a single document based on ``selector``.

        See :http:post:`/{db}/_find`

        Raises:
            Missing: if nothing matches.
            TooManyResults: if more than one document matches.
        """
        # A limit of 2 is enough to tell "exactly one" from "more than one".
        json_body = {"selector": selector, "limit": 2}
        if use_index is not None:
            json_body["use_index"] = use_index

        resp = await self._session._request("POST", self._name, "_find", json=json_body)

        results = resp.json().get("docs")
        count = len(results)
        if count == 0:
            raise Missing("No results found.")
        if count == 1:
            return self._blob2doc(results[0], self._name, ...)
        raise TooManyResults("More than one result found.")

    async def find(
        self,
        selector: typing.Mapping,
        use_index: str | list[str] | None = None,
        pagesize: int | None = None,
    ) -> AsyncIterator:
        """
        Generate documents based on ``selector``.

        See :http:post:`/{db}/_find`

        Pages are requested one after another, feeding each response's
        ``bookmark`` into the next request, until a page comes back empty.
        A ``warning`` in a response is emitted as a :class:`FindWarning`.

        Args:
            selector: The query selector.
            use_index: Sent as ``use_index`` when not ``None``.
            pagesize: Sent as ``limit`` when not ``None``.
        """
        json_body: dict[str, Any] = {"selector": selector, "bookmark": None}

        if use_index is not None:
            json_body["use_index"] = use_index

        if pagesize is not None:
            json_body["limit"] = pagesize

        while True:
            resp = await self._session._request(
                "POST", self._name, "_find", json=json_body
            )
            payload = resp.json()
            if payload.get("warning", None):
                warnings.warn(payload["warning"], FindWarning)

            for doc in payload["docs"]:
                yield self._blob2doc(doc, self._name, ...)

            if not payload["docs"]:
                break

            json_body["bookmark"] = payload["bookmark"]

    async def attempt_put(
        self,
        doc,
        docid: str | None = None,
        *,
        batch: bool = False,
    ):
        """
        Update a document.

        ``docid`` only needs to be given if it's a new document; for a
        document that already carries an ID, that ID takes precedence.

        On success the document's bookkeeping (ID, ETag, revision) is
        refreshed from the response, as far as the response provides it.

        See :http:put:`/{db}/{docid}`

        Raises:
            Conflict: if the server answers 409.
        """
        blob, _db, _docid, etag = self._doc2blob(doc)
        assert _db is None or _db == self._name
        resp = await self._session._request(
            "PUT",
            self._name,
            _docid or docid,
            params={"batch": "ok"} if batch else {},
            headers={"If-Match": etag} if etag else {},
            json=blob,
        )
        payload = resp.json()
        assert payload["ok"]
        # The ETag header and "rev" field may be absent (e.g. for batch-mode
        # writes); _touch_doc skips values that are None.
        self._touch_doc(
            doc,
            db=self._name,
            id=payload["id"],
            etag=resp.headers.get("ETag"),
            rev=payload.get("rev"),
        )

    async def attempt_delete(self, doc, *, batch: bool = False):
        """
        Delete a document

        The document must have been loaded from (or saved to) this database.

        See :http:delete:`/{db}/{docid}`

        Raises:
            Conflict: if the server answers 409.
            Missing: if the server answers 404.
        """
        _, db, docid, etag = self._doc2blob(doc)
        assert db == self._name
        assert docid
        await self._session._request(
            "DELETE",
            db,
            docid,
            params={"batch": "ok"} if batch else {},
            headers={"If-Match": etag},
        )

    async def attempt_copy(self, src_doc, dst_doc, *, batch: bool = False):
        """
        Copy a document

        .. todo::

           Implement

        See :http:copy:`/{db}/{docid}`
        """
        # FIXME: Figure out signature

    async def mutate(self, docid: str) -> AsyncIterator:
        """
        A document mutation loop::

            async for doc in db.mutate("spam"):
                doc.foo = "bar"

        Will replay the mutation until it goes through.
        """
        doc = await self.get(docid)
        while True:
            yield doc
            try:
                await self.attempt_put(doc)
            except Conflict:
                # Somebody else won the race; reload and let the caller
                # re-apply the mutation.
                doc = await self.get(docid)
            else:
                break

    async def iter_all_docs(
        self, *, include_docs: bool = False
    ) -> AsyncIterator[structs.AllDocs_DocRef]:
        """
        List all documents.

        This excludes design documents, see :meth:`.iter_design_docs`.

        Args:
            include_docs: Pre-load documents

        See :http:get:`/{db}/_all_docs`
        """
        # TODO: Pagination
        resp = await self._session._request(
            "GET",
            self._name,
            "_all_docs",
            params={
                "include_docs": include_docs,
            },
            headers={
                "Accept": "application/json",
            },
        )
        for ref in resp.json()["rows"]:
            if ref["id"].startswith("_design/"):
                continue
            doc = (
                self._blob2doc(ref["doc"], self._name, ref["id"])
                if "doc" in ref
                else None
            )
            yield structs.AllDocs_DocRef(
                _db=self, docid=ref["id"], rev=ref["value"]["rev"], _doc=doc
            )

    async def iter_indexes(self) -> AsyncIterator[structs.Index]:
        """
        List the database's indexes, skipping the built-in ``_all_docs``
        index.

        See :http:get:`/{db}/_index`
        """
        resp = await self._session._request("GET", self._name, "_index")
        for idx in resp.json()["indexes"]:
            if idx["ddoc"] is None and idx["name"] == "_all_docs":
                continue
            yield structs.Index.from_dict(idx)

    async def add_index(
        self,
        name: str | None = None,
        *,
        fields: list[str],
        ddoc: str | None = None,
        type: str = "json",
    ):
        """
        Create an index.

        See :http:post:`/{db}/_index`

        Args:
            name: Sent as the index ``name``.
            fields: The fields to index.
            ddoc: Sent as ``ddoc``.
            type: Sent as the index ``type``.
        """
        await self._session._request(
            "POST",
            self._name,
            "_index",
            json={
                "index": {"fields": list(fields)},
                "name": name,
                "ddoc": ddoc,
                "type": type,
            },
        )
638
639
[docs]
class NoServerFound(Exception):
    """
    Not one of the configured servers appears to be working.
    """
644
645
[docs]
class SessionPool:
    """
    Responsible for giving out Couch connections.

    You probably want to override like::

        class MyPool(SessionPool):
            session_class = MySession
    """

    _client: httpx.AsyncClient

    #: Class to use for sessions
    session_class: type[CouchSession]

    def __init__(self):
        super().__init__()
        self._client = self.make_client()

    def make_client(self) -> httpx.AsyncClient:
        """
        Produce an httpx client.
        """
        return httpx.AsyncClient(http2=True, follow_redirects=True)

    async def iter_servers(self) -> AsyncIterator[str | httpx.URL]:
        """
        Produce the list of potential servers.

        Override this
        """
        raise NotImplementedError
        # Unreachable, but its presence makes this an async generator so
        # that ``async for`` over the base implementation raises the
        # NotImplementedError above.
        yield

    async def _check_server(self, url: httpx.URL):
        """
        Probe a server by requesting ``_up`` relative to ``url``.

        Raises:
            httpx.HTTPStatusError: if the response has an error status.
        """
        resp = await self._client.get(url.join("_up"))
        resp.raise_for_status()

    async def session(self) -> CouchSession:
        """
        Get a session

        Servers from :meth:`iter_servers` are probed in order; the first one
        that passes the check is used.

        Raises:
            NoServerFound: if no server passes the check. When checks failed
                with errors, those are attached as an ``ExceptionGroup``
                cause.
        """
        excs = []
        async for url in self.iter_servers():
            url = httpx.URL(url)
            try:
                await self._check_server(url)
            except Exception as e:
                # Deliberately broad: any failure just means "try the next
                # server"; the errors are reported together below.
                excs.append(e)
            else:
                return self.session_class(self._client, url)

        if not excs:
            # ExceptionGroup refuses an empty sequence, so there is no cause
            # to attach when iter_servers() yielded nothing.
            raise NoServerFound("No servers were available to check")
        raise NoServerFound() from ExceptionGroup(
            "There were errors checking servers", excs
        )