diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index 6b776cfb777b3..3ffde6f224493 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -844,6 +844,15 @@ static inline sycl::context get_default_context(); // Util function to get a CPU device. static inline device_ext &cpu_device(); +/// Filter out devices; only keep the device whose name contains one of the +/// subname in \p dev_subnames. +/// May break device id mapping and change current device. It's better to be +/// called before other SYCLcompat or SYCL APIs. +static inline void filter_device(const std::vector &dev_subnames); + +/// Print all the devices (and their IDs) in the dev_mgr +static inline void list_devices(); + // Util function to select a device by its id static inline unsigned int select_device(unsigned int id); @@ -868,6 +877,12 @@ can be queried through `device_ext` as well. throws a `sycl::exception` if the device does not have the specified list of `sycl::aspect`. +Devices can be listed and filtered using `syclcompat::list_devices()` and +`syclcompat::filter_device()`. If `SYCLCOMPAT_VERBOSE` is defined at compile +time, the available SYCL devices are printed to the standard output both at +initialization time, and when the device list is filtered using +`syclcompat::filter_device`. + Users can manage queues through the `syclcompat::set_default_queue(sycl::queue q)` free function, and the `device_ext` `set_saved_queue`, `set_default_queue`, and `get_saved_queue` member functions. diff --git a/sycl/include/syclcompat/device.hpp b/sycl/include/syclcompat/device.hpp index 399efbd8b8933..080fac3ef5275 100644 --- a/sycl/include/syclcompat/device.hpp +++ b/sycl/include/syclcompat/device.hpp @@ -726,14 +726,64 @@ class dev_mgr { unsigned int device_count() { return _devs.size(); } unsigned int get_device_id(const sycl::device &dev) { + if (!_devs.size()) { + throw std::runtime_error( + "[SYCLcompat] No SYCL devices found in the device list. Device list " + "may have been filtered by syclcompat::filter_device"); + } unsigned int id = 0; for (auto dev_item : _devs) { if (*dev_item == dev) { - break; + return id; } id++; } - return id; + throw std::runtime_error("[SYCLcompat] The device[" + + dev.get_info() + + "] is filtered out by syclcompat::filter_device " + "in current device list!"); + } + + /// List all the devices with its id in dev_mgr. + void list_devices() const { + for (size_t i = 0; i < _devs.size(); ++i) { + std::cout << "Device " << i << ": " + << _devs[i]->get_info() << std::endl; + } + } + + /// Filter out devices; only keep the device whose name contains one of the + /// subname in \p dev_subnames. + /// May break device id mapping and change current device. It's better to be + /// called before other SYCLcompat/SYCL APIs. + void filter(const std::vector &dev_subnames) { + std::lock_guard lock(m_mutex); + auto iter = _devs.begin(); + while (iter != _devs.end()) { + std::string dev_name = (*iter)->get_info(); + bool matched = false; + for (const auto &name : dev_subnames) { + if (dev_name.find(name) != std::string::npos) { + matched = true; + break; + } + } + if (matched) + ++iter; + else + iter = _devs.erase(iter); + } + _cpu_device = -1; + for (unsigned i = 0; i < _devs.size(); ++i) { + if (_devs[i]->is_cpu()) { + _cpu_device = i; + break; + } + } + _thread2dev_map.clear(); +#ifdef SYCLCOMPAT_VERBOSE + list_devices(); +#endif } /// Select device with a Device Selector @@ -779,6 +829,9 @@ class dev_mgr { _cpu_device = _devs.size() - 1; } } +#ifdef SYCLCOMPAT_VERBOSE + list_devices(); +#endif } void check_id(unsigned int id) const { if (id >= _devs.size()) { @@ -853,6 +906,19 @@ static inline device_ext &cpu_device() { return detail::dev_mgr::instance().cpu_device(); } +/// Filter out devices; only keep the device whose name contains one of the +/// subname in \p dev_subnames. +/// May break device id mapping and change current device. It's better to be +/// called before other SYCLcompat or SYCL APIs. +static inline void filter_device(const std::vector &dev_subnames) { + detail::dev_mgr::instance().filter(dev_subnames); +} + +/// List all the devices with its id in dev_mgr. +static inline void list_devices() { + detail::dev_mgr::instance().list_devices(); +} + static inline unsigned int select_device(unsigned int id) { detail::dev_mgr::instance().select_device(id); return id; diff --git a/sycl/test-e2e/syclcompat/device/device.cpp b/sycl/test-e2e/syclcompat/device/device.cpp index 9e4c8edcd91c9..0845859c5d55a 100644 --- a/sycl/test-e2e/syclcompat/device/device.cpp +++ b/sycl/test-e2e/syclcompat/device/device.cpp @@ -359,6 +359,24 @@ void test_max_nd_range() { #endif } +void test_list_devices() { + std::cout << __PRETTY_FUNCTION__ << std::endl; + DeviceTestsFixt dtf; + + // Redirect std::cout to count new lines + CountingStream countingBuf(std::cout.rdbuf()); + std::streambuf *orig_buf = std::cout.rdbuf(); + std::cout.rdbuf(&countingBuf); + + syclcompat::list_devices(); + + // Restore back std::cout + std::cout.rdbuf(orig_buf); + + // Expected one line per device + assert(countingBuf.get_line_count() == dtf.get_n_devices()); +} + int main() { test_at_least_one_device(); test_matches_id(); @@ -377,6 +395,7 @@ int main() { test_version_parsing(); test_image_max_attrs(); test_max_nd_range(); + test_list_devices(); return 0; } diff --git a/sycl/test-e2e/syclcompat/device/device_filter.cpp b/sycl/test-e2e/syclcompat/device/device_filter.cpp new file mode 100644 index 0000000000000..3f03432401b0a --- /dev/null +++ b/sycl/test-e2e/syclcompat/device/device_filter.cpp @@ -0,0 +1,78 @@ +/*************************************************************************** + * + * Copyright (C) Codeplay Software Ltd. + * + * Part of the LLVM Project, under the Apache License v2.0 with LLVM + * Exceptions. See https://llvm.org/LICENSE.txt for license information. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SYCLcompat API + * + * device_filter.cpp + * + * Description: + * Device filtering tests + **************************************************************************/ + +// RUN: %clangxx -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out +// RUN: %{run} %t.out + +#include + +void test_filtering_existing_device() { + auto &dev = syclcompat::get_current_device(); + std::string dev_name = dev.get_info(); + + syclcompat::filter_device({dev_name}); + try { + syclcompat::get_device_id(dev); + } catch (std::runtime_error const &e) { + std::cout << " Unexpected SYCL exception caught: " << e.what() + << std::endl; + assert(0); + } + + // Checks for a substring of the device as well + std::string dev_substr = dev_name.substr(1, dev_name.find(" ") + 2); + syclcompat::filter_device({dev_substr}); + try { + syclcompat::get_device_id(dev); + } catch (std::runtime_error const &e) { + std::cout << " Unexpected SYCL exception caught: " << e.what() + << std::endl; + assert(0); + } +} + +void test_filter_devices() { + auto &dev = syclcompat::get_current_device(); + + assert(syclcompat::detail::dev_mgr::instance().device_count() > 0); + + syclcompat::filter_device({"NON-EXISTENT DEVICE"}); + assert(syclcompat::detail::dev_mgr::instance().device_count() == 0); + + try { + syclcompat::get_device_id(dev); + assert(0); + } catch (std::runtime_error const &e) { + std::cout << " Expected SYCL exception caught: " << e.what() << std::endl; + } +} + +int main() { + // syclcompat::dev_mgr is a singleton, so any changes to the device list is + // permanent between tests. Test isolated instead of relying on it being the + // last test in a different test suite. + test_filtering_existing_device(); + + test_filter_devices(); + + return 0; +} diff --git a/sycl/test-e2e/syclcompat/device/device_fixt.hpp b/sycl/test-e2e/syclcompat/device/device_fixt.hpp index 3a68eaf2317f1..ac0cc867a08d9 100644 --- a/sycl/test-e2e/syclcompat/device/device_fixt.hpp +++ b/sycl/test-e2e/syclcompat/device/device_fixt.hpp @@ -50,3 +50,32 @@ class DeviceExtFixt { syclcompat::device_ext &get_dev_ext() { return dev_; } }; + +// Helper for counting the output lines of syclcompat::list_devices +// Used to override std::cout +class CountingStream : public std::streambuf { +public: + CountingStream(std::streambuf *buf) : buf(buf), line_count(0) {} + + int overflow(int c) override { + if (c == '\n') { + ++line_count; + } + return buf->sputc(c); + } + + std::streamsize xsputn(const char_type *s, std::streamsize count) override { + for (std::streamsize i = 0; i < count; ++i) { + if (s[i] == '\n') { + ++line_count; + } + } + return buf->sputn(s, count); + } + + int get_line_count() const { return line_count; } + +private: + std::streambuf *buf; + int line_count; +};