data.cpp
6.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
/*===--------------------------------------------------------------------------
* ATMI (Asynchronous Task and Memory Interface)
*
* This file is distributed under the MIT License. See LICENSE.txt for details.
*===------------------------------------------------------------------------*/
#include "data.h"
#include "atmi_runtime.h"
#include "internal.h"
#include "machine.h"
#include "rt.h"
#include <cassert>
#include <cstdlib>
#include <hsa.h>
#include <hsa_ext_amd.h>
#include <iostream>
#include <memory>
#include <stdio.h>
#include <string.h>
#include <thread>
#include <vector>
// Make core::TaskImpl usable without qualification at file scope.
using core::TaskImpl;
// Machine description, declared here and defined outside this file; it is
// queried below to look up a processor by device type and device id.
extern ATLMachine g_atl_machine;
namespace core {
// Tracker mapping memory ranges to their ATLData metadata. Entries are added
// by register_allocation() and looked up by Runtime::Memfree() and
// Runtime::Memcpy().
ATLPointerTracker g_data_map; // Track all am pointer allocations.
// Forward declaration; the definition is not visible in this part of the file.
// register_allocation() calls it for allocations placed on a CPU device.
void allow_access_to_all_gpu_agents(void *ptr);
// Map a device type to a short printable label.
// Returns "CPU" or "GPU" for those two device types and a null pointer for
// any other value, so callers must check the result before printing it.
const char *getPlaceStr(atmi_devtype_t type) {
  if (type == ATMI_DEVTYPE_CPU) {
    return "CPU";
  }
  if (type == ATMI_DEVTYPE_GPU) {
    return "GPU";
  }
  return nullptr;
}
// Stream a one-line description of a tracked allocation: its pointer, its
// size in bytes, and its memory place as (device type, device id, memory id).
//
// getPlaceStr() returns a null pointer for device types other than CPU/GPU,
// and inserting a null `const char *` into an ostream is undefined behavior,
// so an "UNKNOWN" label is substituted in that case. Output for CPU and GPU
// places is unchanged.
//
// `ap` is dereferenced unconditionally; callers must pass a non-null pointer.
std::ostream &operator<<(std::ostream &os, const ATLData *ap) {
  atmi_mem_place_t place = ap->place();
  const char *place_str = getPlaceStr(place.dev_type);
  os << " devicePointer:" << ap->ptr() << " sizeBytes:" << ap->size()
     << " place:(" << (place_str ? place_str : "UNKNOWN") << ", "
     << place.dev_id << ", " << place.mem_id << ")";
  return os;
}
void ATLPointerTracker::insert(void *pointer, ATLData *p) {
std::lock_guard<std::mutex> l(mutex_);
DEBUG_PRINT("insert: %p + %zu\n", pointer, p->size());
tracker_.insert(std::make_pair(ATLMemoryRange(pointer, p->size()), p));
}
void ATLPointerTracker::remove(void *pointer) {
std::lock_guard<std::mutex> l(mutex_);
DEBUG_PRINT("remove: %p\n", pointer);
tracker_.erase(ATLMemoryRange(pointer, 1));
}
// Look up the metadata for `pointer`, using a one-byte probe range as the key.
// Returns the stored ATLData pointer, or a null pointer when no entry matches.
// The tracker mutex is held for the whole lookup.
ATLData *ATLPointerTracker::find(const void *pointer) {
  std::lock_guard<std::mutex> guard(mutex_);
  const ATLMemoryRange probe(pointer, 1);
  auto hit = tracker_.find(probe);
  DEBUG_PRINT("find: %p\n", pointer);
  if (hit == tracker_.end()) {
    return nullptr; // not tracked
  }
  return hit->second;
}
// Resolve a memory place to the processor it names.
//
// place.dev_type selects the CPU or GPU processor list of g_atl_machine and
// place.dev_id indexes into it. Any other device type is a caller error: there
// is no processor to return a reference to, so the function reports the value
// on stderr and aborts instead of falling off the end of the function (which
// is undefined behavior for a function returning a reference).
//
// NOTE(review): dev_id is used as an index without a range check, as in the
// original code; callers are assumed to pass a valid device id.
ATLProcessor &get_processor_by_mem_place(atmi_mem_place_t place) {
  const int dev_id = place.dev_id;
  switch (place.dev_type) {
  case ATMI_DEVTYPE_CPU:
    return g_atl_machine.processors<ATLCPUProcessor>()[dev_id];
  case ATMI_DEVTYPE_GPU:
    return g_atl_machine.processors<ATLGPUProcessor>()[dev_id];
  default:
    break;
  }
  fprintf(stderr, "get_processor_by_mem_place: unsupported device type %d\n",
          static_cast<int>(place.dev_type));
  std::abort();
}
// Return the HSA agent of the processor that `place` refers to.
static hsa_agent_t get_mem_agent(atmi_mem_place_t place) {
  ATLProcessor &owner = get_processor_by_mem_place(place);
  return owner.agent();
}
// Return the memory pool selected by place.mem_id on the processor that
// `place` refers to.
hsa_amd_memory_pool_t get_memory_pool_by_mem_place(atmi_mem_place_t place) {
  return get_memory_pool(get_processor_by_mem_place(place), place.mem_id);
}
void register_allocation(void *ptr, size_t size, atmi_mem_place_t place) {
ATLData *data = new ATLData(ptr, size, place);
g_data_map.insert(ptr, data);
if (place.dev_type == ATMI_DEVTYPE_CPU)
allow_access_to_all_gpu_agents(ptr);
// TODO(ashwinma): what if one GPU wants to access another GPU?
}
// Allocate `size` bytes from the memory pool selected by `place` and register
// the new allocation in g_data_map.
//
// ptr:   out-parameter that receives the allocated address on success.
// size:  number of bytes requested.
// place: device type, device id and memory id identifying the pool.
//
// Returns ATMI_STATUS_SUCCESS when the allocation succeeded and was
// registered, ATMI_STATUS_ERROR otherwise.
//
// On failure nothing is registered and `*ptr` is not read: the allocator gives
// no usable address in that case, so tracking it (or printing it) would record
// a bogus range. ErrorCheck is still invoked first, so whatever failure
// handling that macro performs is unchanged.
atmi_status_t Runtime::Malloc(void **ptr, size_t size, atmi_mem_place_t place) {
  hsa_amd_memory_pool_t pool = get_memory_pool_by_mem_place(place);
  hsa_status_t err = hsa_amd_memory_pool_allocate(pool, size, 0, ptr);
  ErrorCheck(atmi_malloc, err);
  if (err != HSA_STATUS_SUCCESS) {
    return ATMI_STATUS_ERROR;
  }
  DEBUG_PRINT("Malloced [%s %d] %p\n",
              place.dev_type == ATMI_DEVTYPE_CPU ? "CPU" : "GPU", place.dev_id,
              *ptr);
  register_allocation(*ptr, size, place);
  return ATMI_STATUS_SUCCESS;
}
// Release an allocation: drop its tracker entry, delete its metadata, and
// return the memory to HSA.
//
// Returns ATMI_STATUS_SUCCESS only when `ptr` was a tracked allocation and
// hsa_amd_memory_pool_free succeeded; ATMI_STATUS_ERROR otherwise.
//
// Whether `ptr` was tracked is captured in a bool before the metadata is
// deleted, so the final status check never reads a pointer value that has
// already been passed to `delete`.
//
// NOTE(review): as in the original code, an untracked `ptr` is still handed to
// hsa_amd_memory_pool_free after the error is reported; confirm whether that
// should instead be an early return.
atmi_status_t Runtime::Memfree(void *ptr) {
  ATLData *data = g_data_map.find(ptr);
  const bool tracked = (data != nullptr);
  if (!tracked) {
    ErrorCheck(Checking pointer info userData,
               HSA_STATUS_ERROR_INVALID_ALLOCATION);
  }

  g_data_map.remove(ptr);
  delete data; // deleting a null pointer is a no-op

  hsa_status_t err = hsa_amd_memory_pool_free(ptr);
  ErrorCheck(atmi_free, err);
  DEBUG_PRINT("Freed %p\n", ptr);

  if (err != HSA_STATUS_SUCCESS || !tracked) {
    return ATMI_STATUS_ERROR;
  }
  return ATMI_STATUS_SUCCESS;
}
// Perform a blocking copy of `size` bytes from `src` to `dest` through
// hsa_amd_memory_async_copy, using `agent` as both source and destination
// agent and `sig` as the completion signal.
//
// Returns the submission status if the copy could not be queued,
// HSA_STATUS_ERROR if the signal ends at anything other than 0, and
// HSA_STATUS_SUCCESS otherwise.
static hsa_status_t invoke_hsa_copy(hsa_signal_t sig, void *dest,
                                    const void *src, size_t size,
                                    hsa_agent_t agent) {
  const hsa_signal_value_t pending = 1;
  const hsa_signal_value_t completed = 0;

  // Arm the signal before queuing the copy.
  hsa_signal_store_screlease(sig, pending);

  const hsa_status_t submit_status = hsa_amd_memory_async_copy(
      dest, agent, src, agent, size, 0, nullptr, sig);
  if (submit_status != HSA_STATUS_SUCCESS) {
    return submit_status;
  }

  // async_copy reports success by decrementing the signal and failure by
  // setting it to a negative value; keep waiting while it still holds the
  // armed value.
  hsa_signal_value_t observed;
  do {
    observed = hsa_signal_wait_scacquire(sig, HSA_SIGNAL_CONDITION_NE, pending,
                                         UINT64_MAX, ATMI_WAIT_STATE);
  } while (observed == pending);

  return (observed == completed) ? submit_status : HSA_STATUS_ERROR;
}
// Deleter for std::unique_ptr that releases a buffer through atmi_free.
// The status returned by atmi_free is deliberately discarded.
struct atmiFreePtrDeletor {
  void operator()(void *p) {
    static_cast<void>(atmi_free(p)); // ignore failure to free
  }
};
// Copy `size` bytes from `src` to `dest`, where exactly one of the two must be
// an allocation tracked in g_data_map. The copy is staged through a temporary
// buffer allocated at CPU memory place (0, 0, 0).
//
// sig:  signal used by invoke_hsa_copy to wait for the HSA copy.
// dest: destination address.
// src:  source address.
// size: number of bytes to copy.
//
// Returns ATMI_STATUS_SUCCESS on success. Returns ATMI_STATUS_ERROR when
// neither pointer is tracked, when both are tracked (not implemented), or when
// the HSA copy fails; returns atmi_malloc's status when the staging buffer
// cannot be allocated.
//
// The direction is validated before the staging buffer is allocated, so the
// two unsupported cases no longer allocate and free a buffer they never use.
atmi_status_t Runtime::Memcpy(hsa_signal_t sig, void *dest, const void *src,
                              size_t size) {
  ATLData *src_data = g_data_map.find(src);
  ATLData *dest_data = g_data_map.find(dest);

  if (!src_data && !dest_data) {
    // would be host to host, just call memcpy, or missing metadata
    DEBUG_PRINT("atmi_memcpy invoked without metadata\n");
    return ATMI_STATUS_ERROR;
  }
  if (src_data && dest_data) {
    DEBUG_PRINT("atmi_memcpy unimplemented device to device copy\n");
    return ATMI_STATUS_ERROR;
  }

  atmi_mem_place_t cpu = ATMI_MEM_PLACE_CPU_MEM(0, 0, 0);
  void *temp_host_ptr = nullptr;
  atmi_status_t ret = atmi_malloc(&temp_host_ptr, size, cpu);
  if (ret != ATMI_STATUS_SUCCESS) {
    return ret;
  }
  // Releases the staging buffer on every return path below.
  std::unique_ptr<void, atmiFreePtrDeletor> del(temp_host_ptr);

  if (src_data) {
    // Copy from device to scratch to host
    hsa_agent_t agent = get_mem_agent(src_data->place());
    DEBUG_PRINT("Memcpy D2H device agent: %lu\n", agent.handle);
    if (invoke_hsa_copy(sig, temp_host_ptr, src, size, agent) !=
        HSA_STATUS_SUCCESS) {
      return ATMI_STATUS_ERROR;
    }
    memcpy(dest, temp_host_ptr, size);
  } else {
    // Copy from host to scratch to device
    hsa_agent_t agent = get_mem_agent(dest_data->place());
    DEBUG_PRINT("Memcpy H2D device agent: %lu\n", agent.handle);
    memcpy(temp_host_ptr, src, size);
    if (invoke_hsa_copy(sig, dest, temp_host_ptr, size, agent) !=
        HSA_STATUS_SUCCESS) {
      return ATMI_STATUS_ERROR;
    }
  }
  return ATMI_STATUS_SUCCESS;
}
} // namespace core