"""Get a GeoDataFrame from an ArcGIS FeatureLayer."""
from __future__ import annotations
import asyncio
import math
import warnings
from asyncio import gather
from collections.abc import AsyncGenerator, Mapping
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal
from aiohttp import ClientSession
from restgdf._client._protocols import AsyncHTTPSession
from restgdf._config import get_config
from restgdf._logging import get_logger
from restgdf._models._drift import _parse_response
from restgdf._models.responses import FeaturesResponse, LayerMetadata
from restgdf.errors import (
PaginationError,
PaginationInconsistencyWarning,
RestgdfResponseError,
)
from restgdf.telemetry._spans import start_feature_layer_stream_span
from restgdf.utils.getinfo import (
default_data,
default_headers,
get_feature_count,
get_max_record_count,
get_metadata,
get_object_ids,
supports_pagination,
)
from restgdf.utils._http import _arcgis_request, default_timeout
from restgdf.utils._metadata import (
normalize_spatial_reference,
supports_pagination_explicitly,
)
from restgdf.utils._optional import (
require_geo_stack,
require_geodataframe,
require_geopandas_read_file,
require_pandas_concat,
require_pyogrio_list_drivers,
)
from restgdf.utils._pagination import build_pagination_plan
from restgdf.utils.utils import where_var_in_list
if TYPE_CHECKING:
from geopandas import GeoDataFrame
# Module-level cache for the pyogrio driver table. It stays ``None`` until
# ``_get_supported_drivers`` populates it on first use.
supported_drivers: dict[str, str] | None = None
# Logger obtained from ``get_logger("transport")``; used in this module to
# record metadata lookups that fail while resolving a spatial reference.
_METADATA_LOG = get_logger("transport")
def _require_geo_query_support(feature: str) -> None:
    """Fail fast for GeoDataFrame entrypoints when the geo stack is missing.

    Delegates to ``require_geo_stack(feature)`` and discards its return
    value. ``feature`` is the caller-supplied label for the entrypoint
    (presumably used in the error message — confirm in
    ``restgdf.utils._optional``).
    """
    require_geo_stack(feature)
[docs]
def read_file(*args, **kwargs):
    """Resolve geopandas' reader on demand and call it with the given arguments.

    The reader is looked up through ``require_geopandas_read_file`` at call
    time, so importing this module does not itself need geopandas. All
    positional and keyword arguments are forwarded unchanged and the reader's
    result is returned.
    """
    reader = require_geopandas_read_file("GeoDataFrame queries")
    return reader(*args, **kwargs)
def _get_supported_drivers() -> dict[str, str]:
    """Return the pyogrio driver table, resolving it the first time it is needed.

    The result is stored in the module-level ``supported_drivers`` so later
    calls return the cached mapping without touching pyogrio again.
    """
    global supported_drivers
    if supported_drivers is not None:
        return supported_drivers
    list_drivers = require_pyogrio_list_drivers("GeoDataFrame queries")
    supported_drivers = list_drivers()
    return supported_drivers
async def _get_sub_features(
    url: str,
    session: AsyncHTTPSession,
    query_data: dict,
    *,
    batch_index: int | None = None,
    **kwargs,
) -> list[dict[str, Any]]:
    """Run one ``/query`` request and return its features as raw dicts.

    Args:
        url: Layer URL; ``/query`` is appended to form the endpoint.
        session: HTTP session handed to ``_arcgis_request``.
        query_data: Request payload for this batch (copied before sending).
        batch_index: Position of this batch, reported on ``PaginationError``.
        **kwargs: Extra request options. A ``data`` entry is dropped,
            ``headers`` is routed through ``default_headers`` and ``timeout``
            defaults to ``default_timeout()``.

    Returns:
        The parsed envelope's ``features``, or an empty list when that is
        falsy.

    Raises:
        PaginationError: The response sets ``exceededTransferLimit``.
    """
    request_kwargs = {name: value for name, value in kwargs.items() if name != "data"}
    request_kwargs.setdefault("timeout", default_timeout())
    payload = dict(query_data)
    # Headers are pulled out before the remaining options are splatted so the
    # keyword is never passed twice.
    headers = default_headers(request_kwargs.pop("headers", None))
    response = await _arcgis_request(
        session,
        f"{url}/query",
        payload,
        headers=headers,
        **request_kwargs,
    )
    body = await response.json(content_type=None)
    envelope = _parse_response(FeaturesResponse, body, context=f"{url}/query")
    if not envelope.exceeded_transfer_limit:
        return envelope.features or []
    raise PaginationError(
        f"{url}/query returned exceededTransferLimit=true; query batching missed "
        "records and the response page is incomplete.",
        batch_index=batch_index,
        page_size=query_data.get("resultRecordCount"),
    )
async def _feature_batch_generator(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> AsyncGenerator[list[dict[str, Any]]]:
    """Yield raw ArcGIS feature batches without requiring pandas/geopandas.

    Query payloads come from ``get_query_data_batches``. At most
    ``get_config().concurrency.max_concurrent_requests`` fetch tasks are
    started up front, and one replacement is submitted for every task that
    finishes. Batches that finish in the same ``asyncio.wait`` wake-up are
    yielded in ascending batch index; across wake-ups the order follows
    completion, so the overall output order is not guaranteed to match the
    request order.

    Args:
        url: Layer URL passed through to the batch planner and fetch helper.
        session: HTTP session shared by every request.
        **kwargs: Forwarded unchanged to ``get_query_data_batches`` and to
            each ``get_sub_features`` call.

    Yields:
        One list of raw feature dicts per completed batch.

    Raises:
        Any exception raised by a fetch task propagates from the ``await`` on
        that task; tasks that are still running are cancelled on exit.
    """
    query_data_batches = await get_query_data_batches(url, session, **kwargs)
    max_inflight = get_config().concurrency.max_concurrent_requests
    # ``enumerate`` supplies the batch index used for ordering and error reports.
    batch_iter = iter(enumerate(query_data_batches))
    tasks: set[asyncio.Task] = set()
    task_order: dict[asyncio.Task, int] = {}

    def _submit_next() -> asyncio.Task | None:
        # Start the next unsent batch, or return None once all are submitted.
        try:
            idx, query_data = next(batch_iter)
        except StopIteration:
            return None
        task = asyncio.create_task(
            get_sub_features(
                url,
                session,
                query_data=query_data,
                batch_index=idx,
                **kwargs,
            ),
        )
        tasks.add(task)
        task_order[task] = idx
        return task

    try:
        # Prime the window with up to ``max_inflight`` requests.
        for _ in range(max_inflight):
            task = _submit_next()
            if task is None:
                break
        while tasks:
            done, pending = await asyncio.wait(
                tasks,
                return_when=asyncio.FIRST_COMPLETED,
            )
            # Rebinding here is visible to ``_submit_next`` because it reads
            # ``tasks`` from this enclosing scope.
            tasks = set(pending)
            completed_batches: list[list[dict[str, Any]]] = []
            for task in sorted(done, key=task_order.__getitem__):
                # Refill the window before awaiting the result so the slot is
                # reused immediately. ``_submit_next`` already added the task
                # to ``tasks``; the extra ``add`` below is a no-op on a set.
                replacement = _submit_next()
                if replacement is not None:
                    tasks.add(replacement)
                # The task is already done, so this only unwraps the result
                # (or re-raises its exception).
                completed_batches.append(await task)
                task_order.pop(task, None)
            for feature_batch in completed_batches:
                yield feature_batch
    finally:
        # Runs on normal exhaustion, on error, and when the consumer closes
        # the generator early: stop whatever is still in flight.
        for task in tasks:
            if not task.done():
                task.cancel()
[docs]
def get_sub_features(*args, **kwargs):
    """Compatibility wrapper for the raw feature query helper.

    Forwards every positional and keyword argument unchanged to
    ``_get_sub_features`` and returns the coroutine that call produces; the
    caller is responsible for awaiting it.
    """
    return _get_sub_features(*args, **kwargs)
def _feature_to_row_dict(feature: dict[str, Any]) -> dict[str, Any]:
    """Flatten one ArcGIS feature into a single row-shaped dict.

    The row starts as a copy of the feature's ``attributes`` (a missing or
    falsy value counts as empty). A ``geometry`` key on the feature, if
    present, overwrites any attribute of the same name. Every other top-level
    key is copied only when the row does not already contain it.
    """
    attributes = feature.get("attributes") or {}
    flattened = dict(attributes)
    if "geometry" in feature:
        flattened["geometry"] = feature["geometry"]
    reserved = ("attributes", "geometry")
    for name in feature:
        # Attribute values win over same-named top-level keys.
        if name in reserved or name in flattened:
            continue
        flattened[name] = feature[name]
    return flattened
[docs]
def combine_where_clauses(base_where: str | None, extra_where: str) -> str:
    """AND ``extra_where`` onto ``base_where`` unless the base selects everything.

    A base of ``None``, ``""`` or the literal ``"1=1"`` is treated as
    "all records", in which case ``extra_where`` is returned on its own.
    Otherwise both clauses are parenthesised and joined with ``AND``.
    """
    selects_everything = base_where is None or base_where == "" or base_where == "1=1"
    return extra_where if selects_everything else f"({base_where}) AND ({extra_where})"
[docs]
def chunk_values(values: list[int], chunk_size: int) -> list[list[int]]:
    """Split ``values`` into consecutive chunks of at most ``chunk_size`` items.

    Every chunk has exactly ``chunk_size`` items except possibly the last,
    which holds the remainder. An empty input gives an empty list.

    Args:
        values: Items to split; order is preserved.
        chunk_size: Maximum number of items per chunk; must be at least 1.

    Returns:
        A list of slices of ``values``.

    Raises:
        ValueError: ``chunk_size`` is less than 1. A negative size would
            otherwise make ``range`` empty and silently drop every value.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
    return [
        values[start : start + chunk_size]
        for start in range(0, len(values), chunk_size)
    ]
def _advertised_max_record_count_factor(
    metadata: Mapping[str, Any] | LayerMetadata,
) -> float | None:
    """Return the server-advertised ``maxRecordCountFactor`` or ``None``.

    Accepts both the raw metadata mapping returned by low-level helpers
    and the typed :class:`LayerMetadata` model used by the live
    :class:`~restgdf.featurelayer.FeatureLayer` path. Returns ``None``
    when the ``advancedQueryCapabilities`` block is missing, when the
    factor key itself is absent, or when the advertised value is not a
    positive finite number (``None`` / bool / 0 / negative / NaN / ±inf /
    non-numeric / too large to represent as a float). The return
    value is intended to be threaded straight through to
    ``build_pagination_plan(..., advertised_factor=...)``.

    Args:
        metadata: Raw metadata mapping (camelCase keys) or an object that
            exposes ``advanced_query_capabilities``.

    Returns:
        The factor as a ``float``, or ``None`` when no usable value is
        advertised.
    """
    if isinstance(metadata, Mapping):
        aqc = metadata.get("advancedQueryCapabilities")
    else:
        aqc = metadata.advanced_query_capabilities
    if isinstance(aqc, Mapping):
        raw = aqc.get("maxRecordCountFactor")
    else:
        # ``getattr`` with a default also covers ``aqc is None``.
        raw = getattr(aqc, "max_record_count_factor", None)
    if raw is None or isinstance(raw, bool):
        # bool is a subclass of int; reject it so True/False never leak
        # into the numeric path and silently wire advertised_factor=1.0.
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large for a float is as unusable as a
        # non-numeric value and must not escape as an exception.
        return None
    if not math.isfinite(value) or value <= 0:
        # Reject NaN and ±inf; both are parseable by float() but
        # nonsensical as pagination multipliers.
        return None
    return value
[docs]
async def get_query_data_batches(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> list[dict]:
    """Build query payloads for each request needed to read a layer.

    Three strategies are tried in order:

    1. If the feature count does not exceed the layer's max record count,
       a single payload (a copy of ``kwargs["data"]``) is returned.
    2. If the layer supports pagination, one payload per
       ``resultOffset`` / ``resultRecordCount`` window is returned. A
       positive integer ``resultRecordCount`` in the caller's data is
       used as the page size (capped at the max record count); otherwise
       the windows come from ``build_pagination_plan``.
    3. Otherwise object IDs are fetched and one payload per chunk of
       ``max_record_count`` IDs is returned, each AND-ing an ``IN`` filter
       onto the caller's ``where`` clause.

    When the layer metadata advertises an explicit
    ``advancedQueryCapabilities.maxRecordCountFactor`` (R-72), the
    value is forwarded to ``build_pagination_plan`` as
    ``advertised_factor=`` so pagination batch sizes honor the
    server-published upper bound. Layers that do **not** advertise the
    field keep today's byte-for-byte batching: no ``advertised_factor``
    kwarg is supplied and the planner falls back to its
    ``_DEFAULT_FACTOR`` (``1.0``).

    Pages observed at stream time that return zero features while
    setting ``exceededTransferLimit=true`` are flagged with
    ``PaginationInconsistencyWarning`` (R-73) from the internal page
    resolver; see that helper for details.

    Args:
        url: Layer URL.
        session: HTTP session used for the count, metadata and OID lookups.
        **kwargs: Request options; ``data`` holds the base query payload and
            the whole mapping is forwarded to ``get_feature_count`` and
            ``get_object_ids``.

    Returns:
        A list of query payload dicts, one per request to issue.
    """
    # Copy so the caller's ``data`` mapping is never mutated.
    request_data = dict(kwargs.get("data") or {})
    feature_count = await get_feature_count(url, session, **kwargs)
    token = request_data.get("token")
    metadata = await get_metadata(url, session, token=token)
    max_record_count = get_max_record_count(metadata)
    requested_page_size = request_data.get("resultRecordCount")
    if isinstance(requested_page_size, int) and requested_page_size > 0:
        # Never ask for more rows per page than the server allows.
        page_size = min(requested_page_size, max_record_count)
    else:
        page_size = max_record_count
    # NOTE(review): this short-circuit compares against ``max_record_count``
    # rather than ``page_size``, so a caller-supplied ``resultRecordCount``
    # smaller than ``feature_count`` is sent as one request here instead of
    # being paged — confirm that is intended.
    if feature_count <= max_record_count:
        return [request_data]
    if supports_pagination(metadata) and supports_pagination_explicitly(metadata):
        if isinstance(requested_page_size, int) and requested_page_size > 0:
            # Caller chose the page size: fixed-width offset windows, with
            # the last window trimmed to the remaining feature count.
            return [
                {
                    **request_data,
                    "resultOffset": offset,
                    "resultRecordCount": min(page_size, feature_count - offset),
                }
                for offset in range(0, feature_count, page_size)
            ]
        # R-72: opt-in wire of advertised maxRecordCountFactor. Only
        # pass ``advertised_factor`` when the server actually publishes
        # it, so layers without the field keep byte-exact 3.0 batching.
        planner_kwargs: dict[str, Any] = {}
        advertised_factor = _advertised_max_record_count_factor(metadata)
        if advertised_factor is not None:
            planner_kwargs["advertised_factor"] = advertised_factor
        plan = build_pagination_plan(
            feature_count,
            max_record_count,
            **planner_kwargs,
        )
        return [
            {
                **request_data,
                "resultOffset": offset,
                "resultRecordCount": count,
            }
            for offset, count in plan.batches
        ]
    # No usable pagination support: partition by object ID instead.
    object_id_field_name, object_ids = await get_object_ids(url, session, **kwargs)
    base_where = request_data.get("where")
    return [
        {
            **request_data,
            "where": combine_where_clauses(
                base_where,
                where_var_in_list(object_id_field_name, object_id_chunk),
            ),
        }
        for object_id_chunk in chunk_values(object_ids, max_record_count)
    ]
[docs]
async def get_sub_gdf(
    url: str,
    session: AsyncHTTPSession,
    query_data: dict,
    **kwargs,
) -> GeoDataFrame:
    """Fetch one query batch and parse the response text into a GeoDataFrame.

    The payload is a copy of ``query_data``. When the installed pyogrio
    build has no ``ESRIJSON`` driver the request asks for ``f=GeoJSON``
    instead; otherwise the payload's own format setting is left alone.

    Args:
        url: Layer URL; ``/query`` is appended to form the endpoint.
        session: HTTP session handed to ``_arcgis_request``.
        query_data: Request payload for this batch.
        **kwargs: Extra request options. A ``data`` entry is dropped,
            ``headers`` is routed through ``default_headers`` and ``timeout``
            defaults to ``default_timeout()``.

    Returns:
        The frame produced by ``read_file`` with the pyogrio engine.
    """
    _require_geo_query_support("get_sub_gdf()")
    payload = dict(query_data)
    if "ESRIJSON" not in _get_supported_drivers():
        payload["f"] = "GeoJSON"
    request_kwargs = {name: value for name, value in kwargs.items() if name != "data"}
    request_kwargs.setdefault("timeout", default_timeout())
    headers = default_headers(request_kwargs.pop("headers", None))
    response = await _arcgis_request(
        session,
        f"{url}/query",
        payload,
        headers=headers,
        **request_kwargs,
    )
    body = await response.text()
    # No explicit ``driver=`` is passed: doing so raises a warning when using
    # pyogrio w/ ESRIJSON.
    return read_file(body, engine="pyogrio")
[docs]
async def get_gdf_list(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> list[GeoDataFrame]:
    """Fetch every query batch of a layer and return the frames in batch order.

    One task is created per batch. A bounded semaphore sized by
    ``get_config().concurrency.max_concurrent_requests`` limits how many
    requests run at the same time. If any batch fails, the unfinished tasks
    are cancelled and drained before the original exception is re-raised.
    """
    _require_geo_query_support("get_gdf_list()")
    batches = await get_query_data_batches(url, session, **kwargs)
    limiter = asyncio.BoundedSemaphore(get_config().concurrency.max_concurrent_requests)
    batch_tasks = []
    for batch in batches:
        bounded_fetch = _run_get_sub_gdf_bounded(url, session, limiter, batch, **kwargs)
        batch_tasks.append(asyncio.create_task(bounded_fetch))
    try:
        return await gather(*batch_tasks)
    except Exception:
        for batch_task in batch_tasks:
            if batch_task.done():
                continue
            batch_task.cancel()
        # Wait for the cancellations to settle so nothing is left dangling.
        await gather(*batch_tasks, return_exceptions=True)
        raise
async def _run_get_sub_gdf_bounded(
    url: str,
    session: AsyncHTTPSession,
    sem: asyncio.BoundedSemaphore,
    query_data: dict,
    **kwargs,
) -> GeoDataFrame:
    """Run ``get_sub_gdf`` for one batch while holding a slot in ``sem``.

    The semaphore is acquired before the request starts and released when
    the request finishes or raises.
    """
    async with sem:
        frame = await get_sub_gdf(url, session, query_data=query_data, **kwargs)
    return frame
[docs]
async def chunk_generator(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> AsyncGenerator[GeoDataFrame]:
    """
    Asynchronously yield GeoDataFrames from a FeatureLayer in chunks.

    One GeoDataFrame is produced per query payload returned by
    ``get_query_data_batches``. At most
    ``get_config().concurrency.max_concurrent_requests`` requests are in
    flight at once. Chunks that finish in the same ``asyncio.wait`` wake-up
    are yielded in request order; across wake-ups the order follows
    completion. Each yielded chunk has
    ``gdf.attrs["spatial_reference"]`` populated from the layer's metadata
    (R-65) when the layer reports a spatial reference.

    Args:
        url: Layer URL.
        session: HTTP session shared by every request.
        **kwargs: Forwarded to ``get_query_data_batches`` and to each
            ``get_sub_gdf`` call; ``data["token"]`` is also used for the
            metadata lookup.

    Yields:
        One GeoDataFrame per completed batch.
    """
    _require_geo_query_support("chunk_generator()")
    query_data_batches = await get_query_data_batches(url, session, **kwargs)
    request_data = kwargs.get("data") or {}
    token = request_data.get("token") if isinstance(request_data, Mapping) else None
    raw_sr: dict[str, Any] | None
    try:
        metadata = await get_metadata(url, session, token=token)
    except Exception as exc:  # pragma: no cover - metadata errors surface elsewhere
        # Best effort: a failed lookup only means chunks go out without the
        # spatial-reference attribute.
        _METADATA_LOG.debug(
            "spatial_reference.metadata_lookup_failed url=%s operation=chunk_generator",
            url,
            exc_info=exc,
        )
        raw_sr = None
    else:
        raw_sr = _extract_raw_spatial_reference(metadata)
    max_inflight = get_config().concurrency.max_concurrent_requests
    batch_iter = iter(query_data_batches)
    tasks: set[asyncio.Task] = set()
    task_order: dict[asyncio.Task, int] = {}
    next_index = 0

    def _submit_next() -> asyncio.Task | None:
        # Start the next unsent batch, or return None once all are submitted.
        nonlocal next_index
        try:
            query_data = next(batch_iter)
        except StopIteration:
            return None
        task = asyncio.create_task(
            get_sub_gdf(url, session, query_data=query_data, **kwargs),
        )
        tasks.add(task)
        # Remember submission order so simultaneous completions can be sorted.
        task_order[task] = next_index
        next_index += 1
        return task

    try:
        # Prime the window with up to ``max_inflight`` requests.
        for _ in range(max_inflight):
            task = _submit_next()
            if task is None:
                break
        while tasks:
            done, pending = await asyncio.wait(
                tasks,
                return_when=asyncio.FIRST_COMPLETED,
            )
            # Rebinding here is visible to ``_submit_next`` because it reads
            # ``tasks`` from this enclosing scope.
            tasks = set(pending)
            completed_chunks: list[GeoDataFrame] = []
            for task in sorted(done, key=task_order.__getitem__):
                # Refill the window before unwrapping the finished result.
                # ``_submit_next`` already registered the task; the extra
                # ``add`` below is a no-op on a set.
                replacement = _submit_next()
                if replacement is not None:
                    tasks.add(replacement)
                chunk = await task
                task_order.pop(task, None)
                if raw_sr is not None:
                    chunk.attrs["spatial_reference"] = raw_sr
                completed_chunks.append(chunk)
            for chunk in completed_chunks:
                yield chunk
    finally:
        # Runs on exhaustion, on error, and on early close by the consumer.
        for task in tasks:
            if not task.done():
                task.cancel()
[docs]
async def row_dict_generator(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> AsyncGenerator[dict]:
    """Stream a FeatureLayer as flat, row-shaped dicts.

    Each raw feature from ``_feature_batch_generator`` is flattened with
    ``_feature_to_row_dict`` and yielded one at a time.

    .. deprecated:: 2.0
        Module-level ``row_dict_generator`` is retained for backwards
        compatibility. Prefer :meth:`restgdf.FeatureLayer.stream_rows` or
        ``restgdf.adapters.stream.iter_rows`` in new code.
    """
    batches = _feature_batch_generator(url, session, **kwargs)
    async for batch in batches:
        for raw_feature in batch:
            yield _feature_to_row_dict(raw_feature)
[docs]
async def concat_gdfs(gdfs: list[GeoDataFrame]) -> GeoDataFrame:
    """Concatenate GeoDataFrames that share one CRS into a single frame.

    Rows are stacked in list order with a fresh integer index. The result
    carries the first frame's ``attrs`` and CRS. A single-element list is
    returned as-is, without copying.

    All frames are concatenated in one ``concat`` call. This replaces a
    pairwise fold, which re-copied the accumulated frame at every step.

    Args:
        gdfs: Non-empty list of GeoDataFrames with identical ``crs``.

    Returns:
        The combined GeoDataFrame.

    Raises:
        ValueError: ``gdfs`` is empty, or the frames do not all share the
            same CRS.
    """
    if not gdfs:
        # Explicit message instead of a bare IndexError from ``gdfs[0]``.
        raise ValueError("gdfs must not be empty")
    GeoDataFrame = require_geodataframe("GeoDataFrame concatenation")
    concat = require_pandas_concat("GeoDataFrame concatenation")
    first = gdfs[0]
    crs = first.crs
    # Snapshot before concatenating: the combined frame is not relied on to
    # carry the first frame's attrs.
    saved_attrs = dict(first.attrs)
    if not all(gdf.crs == crs for gdf in gdfs):
        raise ValueError("gdfs must have the same crs")
    if len(gdfs) == 1:
        result = first
    else:
        result = GeoDataFrame(concat(gdfs, ignore_index=True), crs=crs)
    result.attrs.update(saved_attrs)
    return result
[docs]
async def gdf_by_concat(
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> GeoDataFrame:
    """Read a whole layer into one GeoDataFrame.

    All batches are fetched with ``get_gdf_list`` and merged with
    ``concat_gdfs``. The merged frame is then stamped with the layer's
    spatial reference via ``_apply_spatial_reference_attr`` before it is
    returned.
    """
    _require_geo_query_support("gdf_by_concat()")
    combined = await concat_gdfs(await get_gdf_list(url, session, **kwargs))
    await _apply_spatial_reference_attr(combined, url, session, **kwargs)
    return combined
[docs]
async def get_gdf(
    url: str,
    session: ClientSession | None = None,
    where: str | None = None,
    token: str | None = None,
    **kwargs,
) -> GeoDataFrame:
    """Read a FeatureLayer into a single GeoDataFrame.

    Args:
        url: Layer URL.
        session: Session to use. When ``None`` a temporary ``ClientSession``
            is created for this call and closed before returning.
        where: Optional where clause; overrides ``data["where"]``.
        token: Optional token; stored as ``data["token"]``.
        **kwargs: Extra request options. ``data`` (if given) is the base
            query payload, passed through ``default_data``; everything else
            is forwarded to ``gdf_by_concat``.

    Returns:
        The layer's features as one GeoDataFrame.

    Raises:
        ValueError: ``token`` and ``data["token"]`` are both given with
            different values.
    """
    _require_geo_query_support("get_gdf()")
    datadict = default_data(kwargs.pop("data", None) or {})
    if where is not None:
        datadict["where"] = where
    if token is not None:
        existing_token = datadict.get("token")
        if existing_token is not None and existing_token != token:
            raise ValueError(
                "Pass token either via token= or data['token'], not both with different values.",
            )
        datadict["token"] = token
    # Open the temporary session only after argument validation, and right
    # before the ``try``, so an early error cannot leave it unclosed.
    # Ownership is decided by ``is None`` alone, so the session that is
    # created is always the one that gets closed.
    owns_session = session is None
    if owns_session:
        session = ClientSession()
    try:
        return await gdf_by_concat(url, session, data=datadict, **kwargs)
    finally:
        if owns_session:
            await session.close()
# ---------------------------------------------------------------------------
# Spatial-reference propagation (R-65)
# ---------------------------------------------------------------------------
def _extract_raw_spatial_reference(
    metadata: LayerMetadata | Mapping[str, Any] | None,
) -> dict[str, Any] | None:
    """Pull the raw ``spatialReference`` dict out of a layer metadata envelope.

    For a ``LayerMetadata`` model the extra (unmodelled) fields are searched
    first, then the alias-keyed dump. For a plain mapping the mapping itself
    is searched. Within each source ``extent.spatialReference`` is preferred
    and a top-level ``spatialReference`` is the fallback. Each candidate goes
    through ``normalize_spatial_reference`` and the first non-``None`` raw
    value wins.

    Returns ``None`` for ``None`` input, for unsupported input types, and when
    no candidate normalizes to a value.
    """
    if metadata is None:
        return None
    if isinstance(metadata, LayerMetadata):
        sources = (
            metadata.model_extra or {},
            metadata.model_dump(by_alias=True, exclude_none=True),
        )
    elif isinstance(metadata, Mapping):
        sources = (dict(metadata), dict(metadata))
    else:
        return None
    for source in sources:
        extent = source.get("extent")
        nested = extent.get("spatialReference") if isinstance(extent, Mapping) else None
        # Extent-level reference first, then the top-level key.
        for candidate in (nested, source.get("spatialReference")):
            if candidate is None:
                continue
            _, raw = normalize_spatial_reference(candidate)
            if raw is not None:
                return raw
    return None
async def _apply_spatial_reference_attr(
    gdf: GeoDataFrame,
    url: str,
    session: AsyncHTTPSession,
    **kwargs,
) -> None:
    """Set ``gdf.attrs['spatial_reference']`` from the layer's metadata (R-65).

    The metadata is fetched with the token found in ``kwargs["data"]`` (when
    that is a mapping). A failed lookup is logged at debug level and the
    frame is left untouched, as it is when the metadata carries no spatial
    reference.
    """
    payload = kwargs.get("data") or {}
    token = None
    if isinstance(payload, Mapping):
        token = payload.get("token")
    try:
        metadata = await get_metadata(url, session, token=token)
    except Exception as exc:  # pragma: no cover - metadata errors surface elsewhere
        _METADATA_LOG.debug(
            "spatial_reference.metadata_lookup_failed url=%s operation=apply_attr",
            url,
            exc_info=exc,
        )
        return
    spatial_reference = _extract_raw_spatial_reference(metadata)
    if spatial_reference is None:
        return
    gdf.attrs["spatial_reference"] = spatial_reference
# ---------------------------------------------------------------------------
# Streaming page-level primitive (BL-24)
# ---------------------------------------------------------------------------
async def _fetch_page_dict(
    url: str,
    session: AsyncHTTPSession,
    query_data: Mapping[str, Any],
    **kwargs,
) -> dict[str, Any]:
    """Request a single ``/query`` page and hand back its decoded JSON object.

    Args:
        url: Layer URL; ``/query`` is appended to form the endpoint.
        session: HTTP session handed to ``_arcgis_request``.
        query_data: Request payload for this page (copied before sending).
        **kwargs: Extra request options. A ``data`` entry is dropped,
            ``headers`` is routed through ``default_headers`` and ``timeout``
            defaults to ``default_timeout()``.

    Returns:
        The decoded response body.

    Raises:
        RestgdfResponseError: The decoded body is not a ``dict``.
    """
    request_kwargs = {name: value for name, value in kwargs.items() if name != "data"}
    request_kwargs.setdefault("timeout", default_timeout())
    payload = dict(query_data)
    headers = default_headers(request_kwargs.pop("headers", None))
    response = await _arcgis_request(
        session,
        f"{url}/query",
        payload,
        headers=headers,
        **request_kwargs,
    )
    body = await response.json(content_type=None)
    if isinstance(body, dict):
        return body
    raise RestgdfResponseError(
        f"{url}/query returned a non-object JSON payload.",
        context="query_response_shape",
        raw=body,
        url=f"{url}/query",
    )
async def _resolve_page(
    url: str,
    session: AsyncHTTPSession,
    page: dict[str, Any],
    query_data: Mapping[str, Any],
    *,
    on_truncation: Literal["raise", "ignore", "split"],
    depth: int,
    max_depth: int,
    request_kwargs: dict[str, Any],
) -> AsyncGenerator[dict[str, Any]]:
    """Yield ``page`` (and any sub-pages) honoring ``on_truncation``.

    A page that does not set ``exceededTransferLimit`` is yielded unchanged.
    For a truncated page the behaviour depends on ``on_truncation``:

    * ``"ignore"``: log a warning and yield the incomplete page.
    * ``"raise"``: raise ``RestgdfResponseError``.
    * ``"split"``: fetch the object IDs matching the page's ``where``
      clause, split them in half, fetch each half as its own query (without
      ``resultOffset`` / ``resultRecordCount``) and resolve each sub-page
      recursively.

    Args:
        url: Layer URL.
        session: HTTP session used for OID lookups and sub-page fetches.
        page: Raw response envelope already fetched for ``query_data``.
        query_data: The request payload that produced ``page``.
        on_truncation: Truncation policy, see above.
        depth: Current recursion depth (0 for a top-level page).
        max_depth: Depth at which ``"split"`` gives up.
        request_kwargs: The caller's original request options; ``data`` is
            used as the base payload for the OID lookup and is stripped from
            the options passed to the fetch helper.

    Yields:
        Raw page envelopes.

    Raises:
        RestgdfResponseError: Policy is ``"raise"``, the split depth limit
            is reached, or one or zero OIDs remain to bisect.

    Warns:
        PaginationInconsistencyWarning: A truncated page carries no features.
    """
    envelope = _parse_response(
        FeaturesResponse,
        page,
        context=f"{url}/query",
    )
    if not envelope.exceeded_transfer_limit:
        yield page
        return
    # R-73: 0-feature + exceededTransferLimit=true is an ArcGIS-side
    # pagination bug — the cursor cannot advance but the service claims
    # more rows exist. Flag it regardless of ``on_truncation`` so the
    # inconsistency is visible to callers who choose to ignore the
    # normal truncation signal.
    if not envelope.features:
        warnings.warn(
            (
                f"{url}/query returned exceededTransferLimit=true with "
                "zero features; pagination cursor cannot advance."
            ),
            PaginationInconsistencyWarning,
            stacklevel=2,
        )
    if on_truncation == "ignore":
        get_logger("pagination").warning(
            "exceededTransferLimit=true on page; continuing (on_truncation='ignore'); "
            "response is incomplete for url=%s",
            url,
        )
        yield page
        return
    if on_truncation == "raise":
        raise RestgdfResponseError(
            f"{url}/query returned exceededTransferLimit=true; response page is incomplete.",
            context="exceededTransferLimit",
            raw=page,
            url=f"{url}/query",
        )
    # on_truncation == "split": bisect OID list under the current predicate.
    if depth >= max_depth:
        raise RestgdfResponseError(
            f"{url}/query: on_truncation='split' reached max depth {max_depth}; "
            "layer cannot be bisected further.",
            context="exceededTransferLimit",
            raw=page,
            url=f"{url}/query",
        )
    # A missing or empty ``where`` is treated as "all records".
    current_where = query_data.get("where", "1=1") or "1=1"
    split_kwargs = {k: v for k, v in request_kwargs.items() if k != "data"}
    split_kwargs["data"] = {
        **(request_kwargs.get("data") or {}),
        "where": current_where,
    }
    # NOTE(review): the OID lookup is keyed on the ``where`` clause only; any
    # ``resultOffset`` / ``resultRecordCount`` window on ``query_data`` is not
    # part of ``split_kwargs``. Confirm whether splitting an offset-paged
    # request can overlap with sibling pages.
    oid_field, oids = await get_object_ids(url, session, **split_kwargs)
    if len(oids) <= 1:
        raise RestgdfResponseError(
            f"{url}/query: on_truncation='split' could not bisect "
            f"{len(oids)} OID(s) further.",
            context="exceededTransferLimit",
            raw=page,
            url=f"{url}/query",
        )
    mid = len(oids) // 2
    halves = (oids[:mid], oids[mid:])
    for half in halves:
        half_where = combine_where_clauses(
            current_where,
            where_var_in_list(oid_field, half),
        )
        sub_qd = dict(query_data)
        sub_qd["where"] = half_where
        # Bisection changes the partitioning scheme; offset/count no longer apply.
        sub_qd.pop("resultOffset", None)
        sub_qd.pop("resultRecordCount", None)
        sub_page = await _fetch_page_dict(
            url,
            session,
            sub_qd,
            **{k: v for k, v in request_kwargs.items() if k != "data"},
        )
        # Recurse so a half that is still truncated is bisected again.
        async for resolved in _resolve_page(
            url,
            session,
            sub_page,
            sub_qd,
            on_truncation=on_truncation,
            depth=depth + 1,
            max_depth=max_depth,
            request_kwargs=request_kwargs,
        ):
            yield resolved
async def _iter_pages_raw(
    url: str,
    session: AsyncHTTPSession,
    *,
    order: Literal["request", "completion"] = "request",
    max_concurrent_pages: int | None = None,
    on_truncation: Literal["raise", "ignore", "split"] = "raise",
    max_split_depth: int = 32,
    span_layer_id: int | None = None,
    span_out_fields: Any = None,
    span_where: str | None = None,
    **kwargs,
) -> AsyncGenerator[dict[str, Any]]:
    """Yield raw ArcGIS page envelopes for a FeatureLayer query.

    Implements the streaming primitive that powers
    :meth:`FeatureLayer.iter_pages`. See that method for the public
    contract on ordering, concurrency, and truncation handling.

    The optional ``span_*`` parameters carry the FeatureLayer-derived
    attributes for the R-61 INTERNAL parent span so the caller does not
    need to import telemetry helpers from ``restgdf.featurelayer`` (see
    ``tests/test_telemetry_no_dangling_imports_from_featurelayer.py``).

    Args:
        url: Layer URL.
        session: HTTP session shared by every request.
        order: ``"request"`` yields pages in the order their requests were
            planned; ``"completion"`` yields them as responses arrive.
        max_concurrent_pages: Upper bound on simultaneously running page
            fetches. ``None`` starts a task for every page at once.
        on_truncation: Policy passed to ``_resolve_page`` for pages that set
            ``exceededTransferLimit``.
        max_split_depth: Recursion limit passed to ``_resolve_page``.
        span_layer_id: Layer id attribute for the telemetry span.
        span_out_fields: Out-fields attribute for the telemetry span.
        span_where: Where-clause attribute for the telemetry span.
        **kwargs: Request options forwarded to ``get_query_data_batches``;
            everything except ``data`` is also forwarded to each page fetch.

    Yields:
        Raw page envelopes, including any sub-pages produced by splitting.

    Raises:
        ValueError: ``order`` or ``on_truncation`` is not an accepted value,
            or ``max_concurrent_pages`` is less than 1.
    """
    if order not in ("request", "completion"):
        raise ValueError(
            f"order must be 'request' or 'completion', got {order!r}",
        )
    if on_truncation not in ("raise", "ignore", "split"):
        raise ValueError(
            "on_truncation must be 'raise', 'ignore', or 'split'; "
            f"got {on_truncation!r}",
        )
    if max_concurrent_pages is not None and max_concurrent_pages < 1:
        raise ValueError(
            f"max_concurrent_pages must be >= 1, got {max_concurrent_pages!r}",
        )
    # R-61: open a NON-current INTERNAL span and end it from the outer
    # ``finally:`` block. Using ``start_as_current_span`` here would attach
    # an asyncio Context token that the async-generator machinery cannot
    # safely detach when the consumer breaks early / calls ``aclose()`` /
    # is cancelled, producing "Failed to detach context" errors and a
    # leaked span. See rd-gate2-phase4a remediation.
    span = start_feature_layer_stream_span(
        layer_url=url,
        layer_id=span_layer_id,
        out_fields=span_out_fields,
        where=span_where,
        order=order,
    )
    # Every task ever created is recorded here so the ``finally`` block can
    # cancel whatever is still running, whichever branch was taken.
    tasks: list[asyncio.Task] = []
    try:
        query_data_batches = await get_query_data_batches(url, session, **kwargs)
        fetch_kwargs = {k: v for k, v in kwargs.items() if k != "data"}

        async def _fetch_bounded(query_data: dict) -> tuple[dict, dict[str, Any]]:
            # Return the payload alongside the page so the resolver knows
            # which request produced it.
            page = await _fetch_page_dict(
                url,
                session,
                query_data,
                **fetch_kwargs,
            )
            return query_data, page

        if max_concurrent_pages is None:
            # Unbounded mode: every page request starts immediately.
            tasks = [
                asyncio.create_task(_fetch_bounded(qd)) for qd in query_data_batches
            ]
            if order == "completion":
                for fut in asyncio.as_completed(tasks):
                    query_data, page = await fut
                    async for resolved in _resolve_page(
                        url,
                        session,
                        page,
                        query_data,
                        on_truncation=on_truncation,
                        depth=0,
                        max_depth=max_split_depth,
                        request_kwargs=kwargs,
                    ):
                        yield resolved
            else:
                # Request order: await tasks in the order they were created.
                for task in tasks:
                    query_data, page = await task
                    async for resolved in _resolve_page(
                        url,
                        session,
                        page,
                        query_data,
                        on_truncation=on_truncation,
                        depth=0,
                        max_depth=max_split_depth,
                        request_kwargs=kwargs,
                    ):
                        yield resolved
            return
        batch_iter = iter(query_data_batches)

        def _submit_next() -> asyncio.Task | None:
            # Start the next unsent page, or return None once all are submitted.
            try:
                query_data = next(batch_iter)
            except StopIteration:
                return None
            task = asyncio.create_task(_fetch_bounded(query_data))
            tasks.append(task)
            return task

        if order == "completion":
            # Bounded, completion order: keep up to ``max_concurrent_pages``
            # fetches running and refill one slot per finished fetch.
            pending: set[asyncio.Task] = set()
            for _ in range(max_concurrent_pages):
                next_task = _submit_next()
                if next_task is None:
                    break
                pending.add(next_task)
            while pending:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                completed_pages: list[tuple[dict, dict[str, Any]]] = []
                for task in done:
                    query_data, page = await task
                    replacement = _submit_next()
                    if replacement is not None:
                        pending.add(replacement)
                    completed_pages.append((query_data, page))
                # Resolve only after the window has been refilled, so new
                # fetches overlap with truncation handling.
                for query_data, page in completed_pages:
                    async for resolved in _resolve_page(
                        url,
                        session,
                        page,
                        query_data,
                        on_truncation=on_truncation,
                        depth=0,
                        max_depth=max_split_depth,
                        request_kwargs=kwargs,
                    ):
                        yield resolved
            return
        # Bounded, request order: a FIFO window that always awaits the oldest
        # outstanding fetch first.
        pending_in_order: list[asyncio.Task] = []
        for _ in range(max_concurrent_pages):
            next_task = _submit_next()
            if next_task is None:
                break
            pending_in_order.append(next_task)
        while pending_in_order:
            task = pending_in_order.pop(0)
            query_data, page = await task
            replacement = _submit_next()
            if replacement is not None:
                pending_in_order.append(replacement)
            async for resolved in _resolve_page(
                url,
                session,
                page,
                query_data,
                on_truncation=on_truncation,
                depth=0,
                max_depth=max_split_depth,
                request_kwargs=kwargs,
            ):
                yield resolved
    finally:
        # Runs on exhaustion, on error, and on early close by the consumer:
        # stop in-flight fetches and close the telemetry span.
        for task in tasks:
            if not task.done():
                task.cancel()
        if span is not None:
            span.end()