1import json
2import typing
3from typing import AsyncIterator, Literal, Callable, Protocol, TypeVar, Generic
4import warnings
5
6import httpx
7
8from . import structs
9
10
#: Type variable standing for the hydrated document type of a loader.
DOCT = typing.TypeVar("DOCT")


class DocumentLoader(Protocol[DOCT]):
    """
    Structural interface for objects that translate between raw JSON blobs
    and hydrated document instances.
    """

    def loadj(self, blob: dict) -> DOCT:
        """
        Hydrate a document object from the JSON blob ``blob``.
        """

    def dumpj(self, doc: DOCT) -> dict:
        """
        Serialize the document ``doc`` into a JSON blob.
        """
29
30
class DocumentRegistry:
    """
    Handles de/serialization, manages migrations, etc.

    Default loader implementation.

    Do not use directly. You probably want one of the :ref:`integrations`.

    Every subclass gets its own, independent set of registered document
    classes and migrations.
    """

    #: Key in the JSON blob under which the registered type name is stored.
    TYPE_KEY = ""

    # Registered type name -> document class.
    _docclasses = {}
    # (before name, after name, conversion function), in registration order.
    _migrations = []

    def __init_subclass__(cls, **kwargs):
        # Fresh containers per subclass; otherwise every registry would share
        # (and pollute) the base class's dict/list.
        super().__init_subclass__(**kwargs)
        cls._docclasses = {}
        cls._migrations = []

    @classmethod
    def _get_class_from_name(cls, name: str) -> type:
        """
        Look up the document class registered under ``name``.

        Raises:
            KeyError: Nothing is registered under ``name``.
        """
        return cls._docclasses[name]

    @classmethod
    def _get_name_from_class(cls, klass: type) -> str:
        """
        Find the registered type name for ``klass``.

        An exact registration wins. Otherwise the first registered class that
        ``klass`` subclasses is used (covers decorators that wrap or subclass
        the registered class).

        Raises:
            ValueError: ``klass`` is not related to any registered class.
        """
        for name, kind in cls._docclasses.items():
            if klass is kind:
                return name
        for name, kind in cls._docclasses.items():
            if issubclass(klass, kind):  # In case of decorator shenanigans
                return name
        raise ValueError(f"Couldn't find name for {klass}")

    @classmethod
    def document(cls, name: str):
        """
        Register a class as a loadable couch document.

        Args:
            name: The type identifier to save to CouchDB

        .. note::

            The given name must be globally unique and must never change.
        """
        # Catch ``@Registry.document`` used without calling it with a name.
        assert not isinstance(name, type)
        assert name not in cls._docclasses

        def _(klass: type):
            cls._docclasses[name] = klass
            return klass

        return _

    @classmethod
    def migration(cls, before: type, after: type):
        """
        Define a function that'll convert between documents.

        The decorated function receives a loaded ``before`` document and must
        return the equivalent ``after`` document. Both classes must already be
        registered, and each source type may have at most one migration.
        Migrations are chained when a document is loaded.
        """
        # Normalize to the document classes previously registered
        bname = cls._get_name_from_class(before)
        aname = cls._get_name_from_class(after)
        # Enforce linearity
        assert not any(b == bname for b, _, _ in cls._migrations)

        def _(func: Callable):
            cls._migrations.append((bname, aname, func))
            return func

        return _

    def load_doc(self, cls: type, blob: dict):
        """
        Converts a JSON blob into a document.

        Override me.
        """
        raise NotImplementedError

    def dump_doc(self, doc) -> dict:
        """
        Convert a document into a JSON blob.

        Override me.
        """
        raise NotImplementedError

    def _migrate(self, bname, doc):
        """
        Repeatedly apply the migration registered for the document's current
        type name until no further migration applies.
        """
        while funcs := [f for b, _, f in self._migrations if b == bname]:
            # ``migration`` guarantees at most one function per source type.
            (func,) = funcs
            doc = func(doc)
            bname = self._get_name_from_class(type(doc))
        return doc

    def loadj(self, blob):
        """
        Convert a JSON blob into a document, migrating it to the newest type.

        The caller's ``blob`` is left unmodified.

        Raises:
            KeyError: ``blob`` lacks :attr:`TYPE_KEY`, or its type name is
                not registered.
        """
        blob = dict(blob)
        typename = blob.pop(self.TYPE_KEY)
        klass = self._get_class_from_name(typename)
        doc = self.load_doc(klass, blob)
        return self._migrate(typename, doc)

    def dumpj(self, doc):
        """
        Convert a document into a JSON blob tagged with its registered name.
        """
        blob = self.dump_doc(doc)
        blob[self.TYPE_KEY] = self._get_name_from_class(type(doc))
        return blob
133
134
class Conflict(Exception):
    """
    The operation was rejected because of a conflict.

    Raised for HTTP 409 responses, and when creating a database that already
    exists.
    """
139
140
class Missing(Exception):
    """
    The requested resource could not be found.

    This corresponds to an HTTP 404 (or a single-result query matching
    nothing). A document that exists only as a tombstone is reported with
    ``Deleted`` instead.
    """
147
148
class Deleted(Exception):
    """
    The requested document exists only as a tombstone (it was deleted).

    Contrast with ``Missing``, which covers a plain 404.
    """
155
156
class TooManyResults(Exception):
    """
    A query that should match exactly one document matched several.

    The request itself succeeded; only the number of results is at fault.
    """
163
164
class FindWarning(UserWarning):
    """
    Carries a ``warning`` message returned by CouchDB's ``_find`` endpoint.
    """
169
170
[docs]
171class CouchSession:
172 """
173 A connection to CouchDB.
174
175 You probably want to override like::
176
177 class MySession(CouchSession):
178 loader = MyRegistry
179
180 """
181
182 _client: httpx.AsyncClient
183 _root: httpx.URL
184
185 #: Class responsible for de/serializing data.
186 loader: type[DocumentLoader]
187
188 def __init__(self, client: httpx.AsyncClient, root: httpx.URL):
189 self._client = client
190 self._root = root
191
192 @staticmethod
193 def _fix_params(params):
194 rv = {}
195 for key, value in params.items():
196 if value is None:
197 continue
198 elif isinstance(value, str):
199 rv[key] = value
200 else:
201 rv[key] = json.dumps(value)
202 return rv
203
204 async def _request(self, method, *urlparts, **kwargs):
205 url = self._root.join("/".join(urlparts))
206 if "params" in kwargs:
207 kwargs["params"] = self._fix_params(kwargs["params"])
208 resp = await self._client.request(method, url, **kwargs)
209 try:
210 resp.raise_for_status()
211 except httpx.HTTPStatusError as exc:
212 exc.add_note(f"Body: {exc.response.text}")
213 match exc.response.status_code:
214 case 404:
215 raise Missing(f"Could not find {'/'.join(urlparts)}") from exc
216 case 409:
217 raise Conflict(f"Conflict updating {'/'.join(urlparts)}") from exc
218 case _:
219 raise
220 return resp
221
222 def __getitem__(self, key: str) -> "Database":
223 """
224 Gets a database.
225
226 (Does not actually check if it exists.)
227 """
228 return Database(self, key)
229
[docs]
230 async def get_db(self, dbname: str) -> "Database":
231 """
232 Gets a database. Checks if it exists.
233
234 See :http:head:`/{db}`
235 """
236 await self._request("HEAD", dbname)
237 return Database(self, dbname)
238
[docs]
239 async def create_db(
240 self,
241 dbname: str,
242 *,
243 shards: int | None = None,
244 replicas: int | None = None,
245 partitioned: bool | None = None,
246 ) -> "Database":
247 """
248 Create a database
249
250 See :http:post:`/{db}`
251 """
252 try:
253 await self._request(
254 "PUT",
255 dbname,
256 params={
257 "q": shards,
258 "n": replicas,
259 "partitioned": partitioned,
260 },
261 )
262 except httpx.HTTPStatusError as exc:
263 match exc.response.status_code:
264 case 412:
265 raise Conflict("Database already exists") from exc
266 case _:
267 raise
268 return Database(self, dbname)
269
[docs]
270 async def delete_db(self, dbname: str):
271 """
272 Delete a database.
273
274 See :http:delete:`/{db}`
275 """
276 await self._request("DELETE", dbname)
277
[docs]
278 async def iter_dbs(self) -> AsyncIterator[str]:
279 """
280 List all databases
281
282 See :http:get:`/_all_dbs`
283 """
284 resp = await self._request("GET", "_all_dbs")
285 for dbname in resp.json():
286 yield dbname
287
288 # TODO: Database metadata
289
290
[docs]
291class Database:
292 """
293 An individual database.
294 """
295
296 def __init__(self, session, name):
297 """
298 :private:
299 """
300 self._session = session
301 self._name = name
302
303 def _blob2doc(self, blob, db, docid, etag=...):
304 if docid is ...:
305 docid = blob["_id"]
306 if etag is ...:
307 etag = f'"{blob["_rev"]}"'
308 doc = self._session.loader().loadj(blob)
309 doc.__db = db
310 doc.__docid = docid
311 doc.__etag = etag
312 return doc
313
314 def _doc2blob(self, doc):
315 blob = self._session.loader().dumpj(doc)
316 db = docid = etag = None
317 try:
318 db = doc.__db
319 docid = doc.__docid
320 etag = doc.__etag
321 except AttributeError:
322 pass
323 return blob, db, docid, etag
324
[docs]
325 async def get(
326 self,
327 docid: str,
328 *,
329 attachments: bool = False,
330 conflicts: bool = False,
331 deleted_conflicts: bool = False,
332 latest: bool = False,
333 local_seq: bool = False,
334 meta: bool = False,
335 open_revs: list[str] | Literal["all"] | None = None,
336 rev: str | None = None,
337 revs: bool = False,
338 revs_info: bool = False,
339 ):
340 """
341 Get a document
342
343 See :http:get:`/{db}/{docid}`
344 """
345 resp = await self._session._request(
346 "GET",
347 self._name,
348 docid,
349 params={
350 "attachments": attachments,
351 "conflicts": conflicts,
352 "deleted_conflicts": deleted_conflicts,
353 "latest": latest,
354 "local_seq": local_seq,
355 "meta": meta,
356 "open_revs": open_revs,
357 "rev": rev,
358 "revs": revs,
359 "revs_info": revs_info,
360 },
361 headers={
362 "Accept": "application/json",
363 },
364 )
365
366 blob = resp.json()
367 if blob.get("_deleted", False): # TODO: Flag to override this
368 raise Deleted("Document {self._name}/{docid} is marked as deleted")
369 if "ETag" in resp.headers:
370 etag = resp.headers["ETag"]
371 else:
372 # Conflicts mode
373 etag = f'"{blob["_rev"]}"'
374
375 # For some reason, requested fiels are omitted when empty
376 if attachments:
377 blob.setdefault("_attachments", {})
378 if conflicts:
379 blob.setdefault("_conflicts", [])
380 if deleted_conflicts:
381 blob.setdefault("_deleted_conflicts", [])
382 if revs:
383 blob.setdefault("_revisions", {})
384 if revs_info or open_revs:
385 blob.setdefault("_revisions", {})
386
387 doc = self._blob2doc(blob, self._name, docid, etag)
388 return doc
389
390 # TODO: Attachments
391
[docs]
392 async def find_one(
393 self, selector: typing.Mapping, use_index: str | list[str] | None = None
394 ):
395 """
396 Get a single document based on ``selector``.
397
398 See :http:post:`/{db}/_find`
399 """
400 json_body = {"selector": selector, "limit": 2}
401
402 if use_index is not None:
403 json_body |= {"use_index": use_index}
404
405 resp = await self._session._request("POST", self._name, "_find", json=json_body)
406
407 blob = resp.json()
408 results = blob.get("docs")
409 match len(results):
410 case 0:
411 raise Missing("No results found.")
412 case 1:
413 return self._blob2doc(results[0], self._name, ...)
414 case _:
415 raise TooManyResults("More than one result found.")
416
[docs]
417 async def find(
418 self,
419 selector: typing.Mapping,
420 use_index: str | list[str] | None = None,
421 pagesize: int | None = None,
422 ) -> AsyncIterator:
423 """
424 Generate documents based on ``selector``.
425
426 See :http:post:`/{db}/_find`
427 """
428 json_body = {"selector": selector, "bookmark": None}
429
430 if use_index is not None:
431 json_body |= {"use_index": use_index}
432
433 if pagesize is not None:
434 json_body |= {"limit": pagesize}
435
436 while True:
437 resp = await self._session._request(
438 "POST", self._name, "_find", json=json_body
439 )
440 payload = resp.json()
441 if payload.get("warning", None):
442 warnings.warn(payload["warning"], FindWarning)
443
444 for doc in payload["docs"]:
445 yield self._blob2doc(doc, self._name, ...)
446
447 if not payload["docs"]:
448 break
449
450 json_body["bookmark"] = payload["bookmark"]
451
[docs]
452 async def attempt_put(
453 self,
454 doc,
455 docid: str | None = None,
456 *,
457 batch: bool = False,
458 ):
459 """
460 Update a document.
461
462 db and docid only need to be given if it's a new document.
463
464 See :http:put:`/{db}/{docid}`
465 """
466 blob, _db, _docid, etag = self._doc2blob(doc)
467 assert _db is None or _db == self._name
468 await self._session._request(
469 "PUT",
470 self._name,
471 _docid or docid,
472 params={"batch": "ok"} if batch else {},
473 headers={"If-Match": etag} if etag else {},
474 json=blob,
475 )
476
[docs]
477 async def attempt_delete(self, doc, *, batch: bool = False):
478 """
479 Delete a document
480
481 See :http:delete:`/{db}/{docid}`
482 """
483 _, db, docid, etag = self._doc2blob(doc)
484 assert db == self._name
485 assert docid
486 await self._session._request(
487 "DELETE",
488 db,
489 docid,
490 params={"batch": "ok"} if batch else {},
491 headers={"If-Match": etag},
492 )
493
[docs]
494 async def attempt_copy(self, src_doc, dst_doc, *, batch: bool = False):
495 """
496 Copy a document
497
498 .. todo::
499
500 Implement
501
502 See :http:copy:`/{db}/{docid}`
503 """
504 # FIXME: Figure out signature
505
[docs]
506 async def mutate(self, docid: str) -> AsyncIterator:
507 """
508 A document mutation loop::
509
510 async for doc in couch.mutate_doc("spam"):
511 doc.foo = "bar"
512
513 Will replay the mutation until it goes through.
514 """
515 doc = await self.get(docid)
516 while True:
517 yield doc
518 try:
519 await self.attempt_put(doc)
520 except Conflict:
521 doc = await self.get(docid)
522 else:
523 break
524
[docs]
525 async def iter_all_docs(
526 self, include_docs: bool = False
527 ) -> AsyncIterator[structs.AllDocs_DocRef]:
528 """
529 List all documents
530
531 TODO: More params
532
533 Args:
534 include_docs: Pre-load documents
535
536 See :http:get:`/{db}/_all_docs`
537 """
538 resp = await self._session._request(
539 "GET",
540 self._name,
541 "_all_docs",
542 params={
543 "include_docs": include_docs,
544 },
545 headers={
546 "Accept": "application/json",
547 },
548 )
549 blob = resp.json()
550 for ref in blob["rows"]:
551 if "doc" in ref:
552 doc = self._blob2doc(ref["doc"], self._name, ref["id"])
553 else:
554 doc = None
555 yield structs.AllDocs_DocRef(
556 _db=self, docid=ref["id"], rev=ref["value"]["rev"], _doc=doc
557 )
558
559 # TODO: Mango searches
560 # TODO: Database operations
561
562
class SessionPool:
    """
    Responsible for giving out Couch connections.

    You probably want to override like::

        class MyPool(SessionPool):
            session_class = MySession

    All sessions handed out share this pool's HTTP client.
    """

    _client: httpx.AsyncClient

    #: Class to use for sessions
    session_class: type[CouchSession]

    def __init__(self):
        super().__init__()
        self._client = self.make_client()

    def make_client(self) -> httpx.AsyncClient:
        """
        Produce an httpx client.
        """
        return httpx.AsyncClient(http2=True, follow_redirects=True)

    async def aclose(self):
        """
        Close the shared HTTP client.

        Sessions obtained from this pool use the same client and cannot make
        requests afterwards.
        """
        await self._client.aclose()

    async def iter_servers(self) -> AsyncIterator[str]:
        """
        Produce the list of potential servers.

        Override this to yield base URLs, in order of preference.
        """
        raise NotImplementedError
        # Unreachable; the ``yield`` makes this an async generator, matching
        # the shape overrides are expected to have.
        for _ in ():
            yield

    async def _check_server(self, url: httpx.URL) -> bool:
        """
        Return whether the server at ``url`` answers its ``_up`` endpoint
        successfully. Unreachable servers count as down.
        """
        try:
            resp = await self._client.get(url.join("_up"))
        except httpx.TransportError:
            # A dead server must not stop the search for a live one.
            return False
        return resp.is_success

    async def session(self) -> CouchSession:
        """
        Get a session

        Returns a session for the first server from :meth:`iter_servers`
        that reports itself as up.

        Raises:
            ConnectionError: No server is available.
        """
        async for server in self.iter_servers():
            url = httpx.URL(server)
            if await self._check_server(url):
                return self.session_class(self._client, url)
        raise ConnectionError("No CouchDB server is available")