_backend.py
4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import scipy._lib.uarray as ua
from . import _pocketfft
class _ScipyBackend:
"""The default backend for fft calculations
Notes
-----
We use the domain ``numpy.scipy`` rather than ``scipy`` because in the
future ``uarray`` will treat the domain as a hierarchy. This means the user
can install a single backend for ``numpy`` and have it implement
``numpy.scipy.fft`` as well.
"""
__ua_domain__ = "numpy.scipy.fft"
@staticmethod
def __ua_function__(method, args, kwargs):
fn = getattr(_pocketfft, method.__name__, None)
if fn is None:
return NotImplemented
return fn(*args, **kwargs)
# Backends that may be selected by name: ``_backend_from_arg`` looks string
# arguments up in this mapping to obtain the backend object.
_named_backends = dict(scipy=_ScipyBackend)
def _backend_from_arg(backend):
"""Maps strings to known backends and validates the backend"""
if isinstance(backend, str):
try:
backend = _named_backends[backend]
except KeyError:
raise ValueError('Unknown backend {}'.format(backend))
if backend.__ua_domain__ != 'numpy.scipy.fft':
raise ValueError('Backend does not implement "numpy.scipy.fft"')
return backend
def set_global_backend(backend):
    """Set the global fft backend.

    The global backend takes precedence over backends added with
    `register_backend`, but is overridden by a backend made active for a
    scope with `set_backend`.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the ``str`` name of a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.

    Notes
    -----
    Any previously set global backend is replaced. By default the global
    backend is the SciPy implementation.
    """
    ua.set_global_backend(_backend_from_arg(backend))
def register_backend(backend):
    """Register a backend for permanent use.

    Registered backends have the lowest priority: they are tried only after
    the global backend.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the ``str`` name of a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.
    """
    validated = _backend_from_arg(backend)
    ua.register_backend(validated)
def set_backend(backend, coerce=False, only=False):
    """Context manager that activates a backend within a fixed scope.

    On entering the ``with`` statement, the given backend is added to the
    list of available backends with the highest priority. On exit, the
    backend state is restored to what it was before the scope was entered.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the ``str`` name of a known backend
        ({'scipy'}) or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to allow expensive conversions of the ``x`` parameter, e.g.
        copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
    only : bool, optional
        If ``True`` and this backend returns ``NotImplemented``, a
        ``BackendNotImplementedError`` is raised immediately and lower
        priority backends are not tried.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> with fft.set_backend('scipy', only=True):
    ...     fft.fft([1])  # Always calls the scipy implementation
    array([1.+0.j])
    """
    resolved = _backend_from_arg(backend)
    return ua.set_backend(resolved, coerce=coerce, only=only)
def skip_backend(backend):
    """Context manager that skips a backend within a fixed scope.

    Inside the ``with`` statement the given backend is not called, whether
    it was registered locally or globally. After the scope exits, the
    backend is considered again.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to skip. Either the ``str`` name of a known backend
        ({'scipy'}) or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend is an unknown name or does not implement
        ``numpy.scipy.fft``.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> fft.fft([1])  # Calls default scipy backend
    array([1.+0.j])
    >>> with fft.skip_backend('scipy'):  # We explicitly skip the scipy backend
    ...     fft.fft([1])  # leaving no implementation available
    Traceback (most recent call last):
        ...
    BackendNotImplementedError: No selected backends had an implementation ...
    """
    return ua.skip_backend(_backend_from_arg(backend))
# Install the SciPy implementation as the global backend at import time, so
# fft calls are served by it unless another backend is configured.
set_global_backend(backend='scipy')