Source code for invenio_records_rest.facets

# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Facets and factories for result filtering and aggregation.

See :data:`invenio_records_rest.config.RECORDS_REST_FACETS` for more
information on how to specify aggregations and filters.
"""

from flask import current_app, request
from invenio_rest.errors import FieldError, RESTValidationError
from invenio_search.engine import dsl
from six import text_type
from werkzeug.datastructures import MultiDict

from invenio_records_rest.utils import make_comma_list_a_list


def terms_filter(field):
    """Create a term filter.

    The returned callable is bound to ``field``: every time it is invoked
    with a list of values it builds a ``terms`` query for that field.

    :param field: Field name.
    :returns: Function that returns the Terms query.
    """

    def inner(values):
        # The field name is dynamic, so it has to be passed as a keyword
        # through dict expansion.
        query_params = {field: values}
        return dsl.Q("terms", **query_params)

    return inner
def range_filter(field, start_date_math=None, end_date_math=None, **kwargs):
    """Create a range filter.

    The returned function expects a list holding exactly one value of the
    form ``<start>--<end>``. Either side may be left empty to get an
    open-ended range, but not both. A start prefixed with ``>`` or an end
    prefixed with ``<`` is turned into a strict bound (``gt`` / ``lt``);
    otherwise the bound is inclusive (``gte`` / ``lte``).

    :param field: Field name.
    :param start_date_math: Starting date. When set, it is appended to the
        start bound as ``<start>||<start_date_math>``.
    :param end_date_math: Ending date. When set, it is appended to the end
        bound as ``<end>||<end_date_math>``.
    :param kwargs: Addition arguments passed to the Range query.
    :returns: Function that returns the Range query.
    :raises RESTValidationError: Raised by the returned function when the
        value is not a single ``<start>--<end>`` range, or when a bound
        consists only of its ``>`` / ``<`` marker.
    """
    # (strict marker, strict operator, inclusive operator, date math) for the
    # start and the end of the range, in that order.
    bound_specs = (
        (">", "gt", "gte", start_date_math),
        ("<", "lt", "lte", end_date_math),
    )

    def invalid_range():
        return RESTValidationError(
            errors=[FieldError(field, "Invalid range format.")]
        )

    def inner(values):
        if len(values) != 1 or values[0].count("--") != 1 or values[0] == "--":
            raise invalid_range()

        range_args = {}
        for bound, spec in zip(values[0].split("--"), bound_specs):
            marker, strict_op, inclusive_op, date_math = spec
            if not bound:
                # Empty side: the range is open-ended in this direction.
                continue
            if bound.startswith(marker):
                operator = strict_op
                bound = bound[1:]
                if not bound:
                    # Only the marker was given (e.g. ">--"). Reject it
                    # like "--" instead of querying with an empty bound.
                    raise invalid_range()
            else:
                operator = inclusive_op
            if date_math:
                bound = "{0}||{1}".format(bound, date_math)
            range_args[operator] = bound

        # Computed bounds take precedence over same-named extra arguments.
        args = kwargs.copy()
        args.update(range_args)
        return dsl.query.Range(**{field: args})

    return inner
def _create_filter_dsl(urlkwargs, definitions): """Create a filter DSL expression.""" filters = [] for name, filter_factory in definitions.items(): values = request.values.getlist(name, type=text_type) if values: filters.append(filter_factory(values)) for v in values: urlkwargs.add(name, v) return (filters, urlkwargs) def _post_filter(search, urlkwargs, definitions): """Ingest post filter in query.""" filters, urlkwargs = _create_filter_dsl(urlkwargs, definitions) for filter_ in filters: search = search.post_filter(filter_) return (search, urlkwargs) def _query_filter(search, urlkwargs, definitions): """Ingest query filter in query.""" filters, urlkwargs = _create_filter_dsl(urlkwargs, definitions) for filter_ in filters: search = search.filter(filter_) return (search, urlkwargs) def _aggregations(search, definitions): """Add aggregations to query.""" if definitions: for name, agg in definitions.items(): search.aggs[name] = agg if not callable(agg) else agg() return search
def default_facets_factory(search, index):
    """Add a default facets to query.

    The facets configuration for ``index`` is read from
    ``RECORDS_REST_FACETS``. Its ``aggs`` are added as aggregations, its
    ``filters`` as query filters and its ``post_filters`` as post filters.

    It's possible to select facets which should be added to query
    by passing their name in `facets` parameter.

    :param search: Basic search object.
    :param index: Index name.
    :returns: A tuple containing the new search object and a dictionary with
        all fields and values used.
    """
    urlkwargs = MultiDict()
    facets = current_app.config["RECORDS_REST_FACETS"].get(index)
    if facets is None:
        # Nothing configured for this index: hand the search back untouched.
        return (search, urlkwargs)

    # Aggregations. Facet names may be given as repeated parameters or as
    # a comma separated list.
    requested = make_comma_list_a_list(request.args.getlist("facets", None))
    available_aggs = facets.get("aggs", {})
    if not requested:
        # No explicit selection means every configured aggregation is used.
        search = _aggregations(search, available_aggs)
    elif available_aggs:
        # Keep only the configured aggregations that were asked for.
        chosen = {
            facet_name: facet_body
            for facet_name, facet_body in available_aggs.items()
            if facet_name in requested
        }
        search = _aggregations(search, chosen)

    # Query filter
    search, urlkwargs = _query_filter(
        search, urlkwargs, facets.get("filters", {})
    )
    # Post filter
    search, urlkwargs = _post_filter(
        search, urlkwargs, facets.get("post_filters", {})
    )
    return (search, urlkwargs)