perf: DB-API uses more efficient query_and_wait when no job ID is provided (#1747)

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #1745  🦕
tswast committed Dec 19, 2023
1 parent 02a7d12 commit d225a94
Showing 7 changed files with 219 additions and 115 deletions.
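In rough terms, the change routes DB-API queries through Client.query_and_wait() whenever the caller does not ask for a specific job ID. A minimal sketch of the resulting behavior (the SQL and job ID are placeholders, and a job ID must be unique within the project):

from google.cloud import bigquery
from google.cloud.bigquery import dbapi

client = bigquery.Client()
cursor = dbapi.connect(client).cursor()

# No job_id: the cursor now uses Client.query_and_wait(), which can answer
# short queries without creating a full query job.
cursor.execute("SELECT 1 AS x")
print(cursor.fetchall())

# Explicit job_id: the cursor keeps the previous Client.query() path so the
# job is created with exactly this ID.
cursor.execute("SELECT 1 AS x", job_id="my-manually-chosen-job-id")
print(cursor.fetchall())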
1 change: 1 addition & 0 deletions google/cloud/bigquery/_job_helpers.py
@@ -491,6 +491,7 @@ def do_query():
job_id=query_results.job_id,
query_id=query_results.query_id,
project=query_results.project,
num_dml_affected_rows=query_results.num_dml_affected_rows,
)

if job_retry is not None:
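This hunk threads num_dml_affected_rows from the internal query results into the RowIterator that Client.query_and_wait() builds. A hedged sketch of how that surfaces to callers (the table name is a placeholder for any table you can modify):

from google.cloud import bigquery

client = bigquery.Client()

# Placeholder DML statement; `my_dataset.my_table` is assumed to exist.
rows = client.query_and_wait(
    "UPDATE `my_dataset.my_table` SET x = x + 1 WHERE true"
)
print(rows.num_dml_affected_rows)  # rows touched by the DML statement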
6 changes: 6 additions & 0 deletions google/cloud/bigquery/client.py
@@ -3963,6 +3963,7 @@ def _list_rows_from_query_results(
timeout: TimeoutType = DEFAULT_TIMEOUT,
query_id: Optional[str] = None,
first_page_response: Optional[Dict[str, Any]] = None,
num_dml_affected_rows: Optional[int] = None,
) -> RowIterator:
"""List the rows of a completed query.
See
@@ -4007,6 +4008,10 @@ def _list_rows_from_query_results(
and not guaranteed to be populated.
first_page_response (Optional[dict]):
API response for the first page of results (if available).
num_dml_affected_rows (Optional[int]):
If this RowIterator is the result of a DML query, the number of
rows that were affected.
Returns:
google.cloud.bigquery.table.RowIterator:
Iterator of row data
@@ -4047,6 +4052,7 @@ def _list_rows_from_query_results(
job_id=job_id,
query_id=query_id,
first_page_response=first_page_response,
num_dml_affected_rows=num_dml_affected_rows,
)
return row_iterator

122 changes: 69 additions & 53 deletions google/cloud/bigquery/dbapi/cursor.py
@@ -14,11 +14,12 @@

"""Cursor for the Google BigQuery DB-API."""

from __future__ import annotations

import collections
from collections import abc as collections_abc
import copy
import logging
import re
from typing import Optional

try:
from google.cloud.bigquery_storage import ArrowSerializationOptions
@@ -34,8 +35,6 @@
import google.cloud.exceptions # type: ignore


_LOGGER = logging.getLogger(__name__)

# Per PEP 249: A 7-item sequence containing information describing one result
# column. The first two items (name and type_code) are mandatory, the other
# five are optional and are set to None if no meaningful values can be
@@ -76,18 +75,31 @@ def __init__(self, connection):
# most appropriate size.
self.arraysize = None
self._query_data = None
self._query_job = None
self._query_rows = None
self._closed = False

@property
def query_job(self):
"""google.cloud.bigquery.job.query.QueryJob: The query job created by
the last ``execute*()`` call.
def query_job(self) -> Optional[job.QueryJob]:
"""google.cloud.bigquery.job.query.QueryJob | None: The query job
created by the last ``execute*()`` call, if a query job was created.
.. note::
If the last ``execute*()`` call was ``executemany()``, this is the
last job created by ``executemany()``."""
return self._query_job
rows = self._query_rows

if rows is None:
return None

job_id = rows.job_id
project = rows.project
location = rows.location
client = self.connection._client

if job_id is None:
return None

return client.get_job(job_id, location=location, project=project)

def close(self):
"""Mark the cursor as closed, preventing its further use."""
@@ -117,8 +129,8 @@ def _set_description(self, schema):
for field in schema
)

def _set_rowcount(self, query_results):
"""Set the rowcount from query results.
def _set_rowcount(self, rows):
"""Set the rowcount from a RowIterator.
Normally, this sets rowcount to the number of rows returned by the
query, but if it was a DML statement, it sets rowcount to the number
@@ -129,10 +141,10 @@ def _set_rowcount(self, query_results):
Results of a query.
"""
total_rows = 0
num_dml_affected_rows = query_results.num_dml_affected_rows
num_dml_affected_rows = rows.num_dml_affected_rows

if query_results.total_rows is not None and query_results.total_rows > 0:
total_rows = query_results.total_rows
if rows.total_rows is not None and rows.total_rows > 0:
total_rows = rows.total_rows
if num_dml_affected_rows is not None and num_dml_affected_rows > 0:
total_rows = num_dml_affected_rows
self.rowcount = total_rows
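_set_rowcount() now reads its counts from the RowIterator, but the DB-API semantics are unchanged: DML statements report the affected-row count, while SELECT statements report the total row count. A sketch with a placeholder table name:

from google.cloud import bigquery
from google.cloud.bigquery import dbapi

cursor = dbapi.connect(bigquery.Client()).cursor()

# Placeholder DML; it matches no rows, so rowcount is 0.
cursor.execute("UPDATE `my_dataset.my_table` SET x = 0 WHERE false")
print(cursor.rowcount)

cursor.execute("SELECT 1 AS x")
print(cursor.rowcount)  # 1 -- total rows returned by the SELECT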
@@ -165,9 +177,10 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
parameters (Union[Mapping[str, Any], Sequence[Any]]):
(Optional) dictionary or sequence of parameter values.
job_id (str):
(Optional) The job_id to use. If not set, a job ID
is generated at random.
job_id (str | None):
(Optional and discouraged) The job ID to use when creating
the query job. For best performance and reliability, manually
setting a job ID is discouraged.
job_config (google.cloud.bigquery.job.QueryJobConfig):
(Optional) Extra configuration options for the query job.
@@ -181,7 +194,7 @@ def _execute(
self, formatted_operation, parameters, job_id, job_config, parameter_types
):
self._query_data = None
self._query_job = None
self._query_results = None
client = self.connection._client

# The DB-API uses the pyformat formatting, since the way BigQuery does
@@ -190,33 +203,35 @@
# libraries.
query_parameters = _helpers.to_query_parameters(parameters, parameter_types)

if client._default_query_job_config:
if job_config:
config = job_config._fill_from_default(client._default_query_job_config)
else:
config = copy.deepcopy(client._default_query_job_config)
else:
config = job_config or job.QueryJobConfig(use_legacy_sql=False)

config = job_config or job.QueryJobConfig()
config.query_parameters = query_parameters
self._query_job = client.query(
formatted_operation, job_config=config, job_id=job_id
)

if self._query_job.dry_run:
self._set_description(schema=None)
self.rowcount = 0
return

# Wait for the query to finish.
# Start the query and wait for the query to finish.
try:
self._query_job.result()
if job_id is not None:
rows = client.query(
formatted_operation,
job_config=job_config,
job_id=job_id,
).result(
page_size=self.arraysize,
)
else:
rows = client.query_and_wait(
formatted_operation,
job_config=config,
page_size=self.arraysize,
)
except google.cloud.exceptions.GoogleCloudError as exc:
raise exceptions.DatabaseError(exc)

query_results = self._query_job._query_results
self._set_rowcount(query_results)
self._set_description(query_results.schema)
self._query_rows = rows
self._set_description(rows.schema)

if config.dry_run:
self.rowcount = 0
else:
self._set_rowcount(rows)

def executemany(self, operation, seq_of_parameters):
"""Prepare and execute a database operation multiple times.
@@ -250,25 +265,26 @@ def _try_fetch(self, size=None):
Mutates self to indicate that iteration has started.
"""
if self._query_job is None:
if self._query_data is not None:
# Already started fetching the data.
return

rows = self._query_rows
if rows is None:
raise exceptions.InterfaceError(
"No query results: execute() must be called before fetch."
)

if self._query_job.dry_run:
self._query_data = iter([])
bqstorage_client = self.connection._bqstorage_client
if rows._should_use_bqstorage(
bqstorage_client,
create_bqstorage_client=False,
):
rows_iterable = self._bqstorage_fetch(bqstorage_client)
self._query_data = _helpers.to_bq_table_rows(rows_iterable)
return

if self._query_data is None:
bqstorage_client = self.connection._bqstorage_client

if bqstorage_client is not None:
rows_iterable = self._bqstorage_fetch(bqstorage_client)
self._query_data = _helpers.to_bq_table_rows(rows_iterable)
return

rows_iter = self._query_job.result(page_size=self.arraysize)
self._query_data = iter(rows_iter)
self._query_data = iter(rows)

def _bqstorage_fetch(self, bqstorage_client):
"""Start fetching data with the BigQuery Storage API.
@@ -290,7 +306,7 @@ def _bqstorage_fetch(self, bqstorage_client):
# bigquery_storage can indeed be imported here without errors.
from google.cloud import bigquery_storage

table_reference = self._query_job.destination
table_reference = self._query_rows._table

requested_session = bigquery_storage.types.ReadSession(
table=table_reference.to_bqstorage(),
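_bqstorage_fetch() now takes its destination from the row iterator's table reference instead of the job's destination table. From the caller's side, opting in to Storage API downloads is unchanged and still goes through the second argument of connect(); a sketch, assuming google-cloud-bigquery-storage is installed:

from google.cloud import bigquery
from google.cloud import bigquery_storage
from google.cloud.bigquery import dbapi

client = bigquery.Client()
bqstorage_client = bigquery_storage.BigQueryReadClient()

# Results fetched through this cursor may be downloaded with the BigQuery
# Storage API when the row iterator decides it is worthwhile.
cursor = dbapi.connect(client, bqstorage_client).cursor()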
2 changes: 2 additions & 0 deletions google/cloud/bigquery/job/query.py
@@ -1614,6 +1614,7 @@ def do_get_result():
project=self.project,
job_id=self.job_id,
query_id=self.query_id,
num_dml_affected_rows=self._query_results.num_dml_affected_rows,
)

# We know that there's at least 1 row, so only treat the response from
@@ -1639,6 +1640,7 @@ def do_get_result():
timeout=timeout,
query_id=self.query_id,
first_page_response=first_page_response,
num_dml_affected_rows=self._query_results.num_dml_affected_rows,
)
rows._preserve_order = _contains_order_by(self.query)
return rows
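With this hunk, iterators returned from QueryJob.result() also carry the DML row count, matching the query_and_wait() path. Sketch (placeholder table):

from google.cloud import bigquery

client = bigquery.Client()
job = client.query("DELETE FROM `my_dataset.my_table` WHERE false")
rows = job.result()
print(rows.num_dml_affected_rows)  # 0 -- this DELETE matches no rows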
32 changes: 24 additions & 8 deletions google/cloud/bigquery/table.py
@@ -1566,6 +1566,7 @@ def __init__(
job_id: Optional[str] = None,
query_id: Optional[str] = None,
project: Optional[str] = None,
num_dml_affected_rows: Optional[int] = None,
):
super(RowIterator, self).__init__(
client,
@@ -1592,6 +1593,7 @@ def __init__(
self._job_id = job_id
self._query_id = query_id
self._project = project
self._num_dml_affected_rows = num_dml_affected_rows

@property
def _billing_project(self) -> Optional[str]:
@@ -1616,6 +1618,16 @@ def location(self) -> Optional[str]:
"""
return self._location

@property
def num_dml_affected_rows(self) -> Optional[int]:
"""If this RowIterator is the result of a DML query, the number of
rows that were affected.
See:
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query#body.QueryResponse.FIELDS.num_dml_affected_rows
"""
return self._num_dml_affected_rows

@property
def project(self) -> Optional[str]:
"""GCP Project ID where these rows are read from."""
@@ -1635,7 +1647,10 @@ def _is_almost_completely_cached(self):
This is useful to know, because we can avoid alternative download
mechanisms.
"""
if self._first_page_response is None:
if (
not hasattr(self, "_first_page_response")
or self._first_page_response is None
):
return False

total_cached_rows = len(self._first_page_response.get(self._items_key, []))
@@ -1655,7 +1670,7 @@ def _is_almost_completely_cached(self):

return False

def _validate_bqstorage(self, bqstorage_client, create_bqstorage_client):
def _should_use_bqstorage(self, bqstorage_client, create_bqstorage_client):
"""Returns True if the BigQuery Storage API can be used.
Returns:
@@ -1669,8 +1684,9 @@ def _validate_bqstorage(self, bqstorage_client, create_bqstorage_client):
if self._table is None:
return False

# The developer is manually paging through results if this is set.
if self.next_page_token is not None:
# The developer has already started paging through results if
# next_page_token is set.
if hasattr(self, "next_page_token") and self.next_page_token is not None:
return False

if self._is_almost_completely_cached():
@@ -1726,7 +1742,7 @@ def schema(self):

@property
def total_rows(self):
"""int: The total number of rows in the table."""
"""int: The total number of rows in the table or query results."""
return self._total_rows

def _maybe_warn_max_results(
@@ -1752,7 +1768,7 @@ def _maybe_warn_max_results(
def _to_page_iterable(
self, bqstorage_download, tabledata_list_download, bqstorage_client=None
):
if not self._validate_bqstorage(bqstorage_client, False):
if not self._should_use_bqstorage(bqstorage_client, False):
bqstorage_client = None

result_pages = (
@@ -1882,7 +1898,7 @@ def to_arrow(

self._maybe_warn_max_results(bqstorage_client)

if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
if not self._should_use_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None

@@ -2223,7 +2239,7 @@ def to_dataframe(

self._maybe_warn_max_results(bqstorage_client)

if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
if not self._should_use_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None

