-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdev_dblptr.py
58 lines (41 loc) · 1.29 KB
/
dev_dblptr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
# Public API of this module: only the device double-pointer wrapper is
# exported by ``from <module> import *``.
__all__ = ["Device_DblPtr"]
from ctypes import cast, c_void_p
import numpy as np
from cuda_helpers import (cu_free, cu_malloc_dblptr)
# Integer type codes passed to ``cu_malloc_dblptr`` to identify the element
# type: float32 -> 0, float64 -> 1, complex64 -> 2, complex128 -> 3.
# NOTE(review): these values must agree with whatever the cuda_helpers
# backend expects for each type — confirm there before reordering.
dtype_map = {
    np.dtype(type_spec): type_code
    for type_code, type_spec in enumerate(("f4", "f8", "c8", "c16"))
}
class Device_DblPtr(object):
    """
    Device-side "double pointer" built from an existing Device_Ptr.

    The allocation is made by ``cu_malloc_dblptr`` and released by
    ``cu_free``. Instances are context managers: leaving a ``with`` block
    frees the device allocation (see ``free``).
    """

    def __init__(self, device_ptr, n, batch_size):
        """
        Allocates space for a double pointer on the device.

        Parameters
        ----------
        device_ptr : Device_Ptr
            Original Device_Ptr object to map to double pointer. Its ``ptr``
            and ``dtype`` attributes are read.
        n : int
            Passed to ``cu_malloc_dblptr`` as ``n*n``. Presumably the
            dimension of square matrices, so ``n*n`` would be the per-entry
            element stride (TODO confirm against cuda_helpers).
        batch_size : int
            Batch size forwarded to ``cu_malloc_dblptr``. It is also the
            value returned by ``len(self)``.

        Raises
        ------
        KeyError
            If ``device_ptr.dtype`` is not a key of ``dtype_map``.
        """
        dtype_code = dtype_map.get(device_ptr.dtype)
        if dtype_code is None:
            # Name the offending dtype and the supported ones instead of a
            # bare KeyError; KeyError is kept so existing handlers still work.
            raise KeyError(
                "unsupported dtype %r for Device_DblPtr; supported dtypes: %s"
                % (device_ptr.dtype,
                   ", ".join(sorted(str(d) for d in dtype_map)))
            )
        dev_dblptr = cu_malloc_dblptr(device_ptr.ptr,
                                      n * n, batch_size,
                                      dtype_code)
        self.ptr = cast(dev_dblptr, c_void_p)
        self.batch_size = batch_size
        self.dtype = device_ptr.dtype

    def __call__(self):
        """Return the device double pointer (``None`` once freed)."""
        return self.ptr

    def __len__(self):
        """Return the batch size this object was created with."""
        return self.batch_size

    def __repr__(self):
        return repr(self.__dict__)

    def __enter__(self):
        return self

    def free(self):
        """
        Release the device allocation via ``cu_free``.

        Safe to call more than once. After a successful free, ``self.ptr``
        is set to ``None`` and later calls do nothing. This means the same
        device address is never handed to ``cu_free`` twice.
        """
        if self.ptr is None:
            return
        cu_free(self.ptr)
        # Drop the now-dangling address so repeated frees (or a later
        # __call__) cannot reuse it.
        self.ptr = None

    def __exit__(self, *args, **kwargs):
        """
        Frees the device memory used by the object on leaving a ``with``
        block. Exceptions raised inside the block are not suppressed.
        """
        self.free()