Skip to content

Commit 65baaa9

Browse files
committed
[ROCm] create rocm_sdk_core_jll
and introduce a platform augmentation mechanism for the ROCm platform. Ref JuliaPackaging#12672
1 parent eb45c09 commit 65baaa9

File tree

3 files changed

+356
-0
lines changed

3 files changed

+356
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using BinaryBuilder, Pkg
2+
3+
const YGGDRASIL_DIR = "../../.."
4+
include(joinpath(YGGDRASIL_DIR, "fancy_toys.jl"))
5+
include(joinpath(YGGDRASIL_DIR, "platforms", "rocm.jl"))
6+
7+
# Name and version of the package produced by this recipe.
name = "rocm_sdk_core"
version = v"7.0.01120251130"

# Text of the platform augmentation hook that sits next to this recipe; it is passed to
# `build_tarballs` as the `augment_platform_block` keyword.
augment_platform_block = read(joinpath(@__DIR__, "platform_augmentation.jl"), String)

# Bash build script: unpack the downloaded wheel, copy the contents of its
# `_rocm_sdk_core` directory into the prefix, and install the license file.
script = raw"""
cd ${WORKSPACE}/srcdir

unzip rocm_sdk_core-*.whl

# Copy the rocm_sysdeps folder
cp -rv _rocm_sdk_core/* ${prefix}/

install_license LICENSE.md
"""

# Products exposed by the package: the `ld.lld` executable under `lib/llvm/bin` and the
# `libamdhip64` library.
products = [
    ExecutableProduct("ld.lld", :lld, "lib/llvm/bin"),
    LibraryProduct("libamdhip64", :libhip),
]
27+
28+
# Dependencies of the package. The entries are comma-separated so that this is an
# ordinary vector literal; separating them only by a newline would make Julia parse the
# brackets as a vertical concatenation instead.
dependencies = [
    Dependency(PackageSpec(name="CompilerSupportLibraries_jll", uuid="e66e0078-7015-5450-92f7-15fbd957f2ae")),
    RuntimeDependency(PackageSpec(name="HSARuntime_jll", uuid="0a197bc1-b33e-53f1-a9ca-cd02b99357ac"); compat = "7"),
]
32+
33+
34+
# determine exactly which tarballs we should build

# SHA256 checksum of the `rocm_sdk_core` wheel for each `rocm_platform` tag value.
# Built once, outside the loop, since it does not depend on the platform being visited.
wheel_sha256 = Dict(
    "gfx101x_dgpu" => "4a1903f4afece374b008d376825a81b9d0d5901844c78db0dce82c17c0c66f8f",
    "gfx103x_dgpu" => "dd4e9eceb3bc93b4f235e27d2572bedae470e32f03609487129fba14f6d512b2",
    "gfx110x_all"  => "45327fb6874797c104275617541e6cce9b706159a962f17b410e06d9c4f66008",
    #"gfx110x_dgpu" => "b0d27556dd07d30345624eb487ca2c7cf40060ec5f94ba6e3d788c2fead4b345",
    "gfx1150"      => "01250b8baa92d45f0af2a32456db8e2d6d42f575afb781ef1b8fee47fe644ed2",
    "gfx1151"      => "565b7e96e04b3f0cdd743eef5e498f697ee407c70c9a0b3702f10d3d0dcb6fc5",
    "gfx120x_all"  => "bbeb58b80951aa0ba4c5df8ce49fc6a493012ecc49910574175fd84c948b1eb2",
    "gfx90x_dcgpu" => "1a1ae75beabba18d5a7d94942f128801ba2448a0e4cdc74da5d9c4b26789bdbe",
    "gfx94x_dcgpu" => "f18fda3295a8e6aa54f2d04b7dcb4631c0c7a2aac57fc774fbacf32f6071bbe0",
    "gfx950_dcgpu" => "923a5fccd4dbd45a13d89e77aa83313e5f5aafb870caddcd0595b741b1886b59",
)

builds = []
for augmented_platform in ROCm.supported_platforms()
    should_build_platform(triplet(augmented_platform)) || continue

    # value of the platform tag, e.g. "gfx103x_dgpu"; this is the key into `wheel_sha256`
    rocm_platform = augmented_platform["rocm_platform"]

    # the download URL spells the same name differently: "x_" becomes "X-" and any
    # remaining "_" becomes "-" (e.g. "gfx103x_dgpu" -> "gfx103X-dgpu")
    index_name = replace(rocm_platform, "x_" => "X-", "_" => "-")

    sources = [
        FileSource("https://rocm.nightlies.amd.com/v2/$index_name/rocm_sdk_core-7.11.0a20251130-py3-none-linux_x86_64.whl",
                   wheel_sha256[rocm_platform]),
        FileSource("https://raw.githubusercontent.com/ROCm/rocm-systems/fd61b0f5073a6c4c3b6693532d3cfb8972b1951f/projects/hip/LICENSE.md",
                   "b185aaa652b0bf066c37a0d6314ce4bf4521e4a3c9bf46edd2f6a777ac522223"),
    ]

    # one build per augmented platform, each with its own wheel as source
    push!(builds, (; platforms=[augmented_platform], sources))
end
64+
65+
# Platform selection is done above through `should_build_platform`, so only the
# `--flag` style arguments are forwarded to `build_tarballs`; positional arguments
# are dropped here.
non_platform_ARGS = filter(startswith("--"), ARGS)

# Same arguments minus `--register`, used for every invocation except the last one.
non_reg_ARGS = filter(!=("--register"), non_platform_ARGS)

for (i, build) in enumerate(builds)
    # only the final invocation is allowed to see `--register`
    cli_args = if i == lastindex(builds)
        non_platform_ARGS
    else
        non_reg_ARGS
    end

    build_tarballs(
        cli_args, name, version, build.sources, script,
        build.platforms, products, dependencies;
        skip_audit = true,
        julia_compat = "1.6",
        lazy_artifacts = true,
        augment_platform_block,
    )
end
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
using Base.BinaryPlatforms
2+
3+
# Preferences of rocm_sdk_core_jll, looked up by UUID through Base's preference API.
const rocm_sdk_core_jll_uuid = Base.UUID("9ab9228b-5f62-5ec0-95ae-72487824505f")
const preferences = Base.get_preferences(rocm_sdk_core_jll_uuid)
Base.record_compiletime_preference(rocm_sdk_core_jll_uuid, "local")

# Value of the "local" preference:
#   * `missing` when the preference is not set, or is set to something invalid
#     (an error is logged in the invalid case);
#   * a `Bool` when it is stored as a boolean or as a string that parses as one.
const local_preference = let
    if !haskey(preferences, "local")
        missing
    else
        stored = preferences["local"]
        # `nothing` marks "could not be interpreted as a boolean"
        parsed = if stored isa Bool
            stored
        elseif stored isa String
            tryparse(Bool, stored)
        else
            nothing
        end

        if parsed === nothing
            @error "ROCm local preference is not valid; expected a boolean, but got '$(preferences["local"])'"
            missing
        else
            parsed
        end
    end
end
25+
26+
try
    using HSARuntime_jll
catch
    # during initial package installation, HSARuntime_jll may not be available.
    # in that case, we just won't select an artifact.
end

# Agent handle as passed to and from the HSA runtime calls below: a single 64-bit field.
struct hsa_agent_t
    handle::UInt64
end

# Attribute identifier handed to `hsa_agent_get_info` to request the agent name.
# Written as `Cint(0)` rather than as a type-annotated global (`const x::Cint = 0`);
# the value and its type are the same.
# NOTE(review): the annotated form appears to need Julia 1.8+, while the recipe
# declares `julia_compat="1.6"` -- confirm.
const HSA_AGENT_INFO_NAME = Cint(0)
38+
39+
# Status codes returned by the HSA runtime calls made in this file, stored as `Cint`.
@enum hsa_status_t::Cint begin
    # non-error results
    HSA_STATUS_SUCCESS                          = 0x0
    HSA_STATUS_INFO_BREAK                       = 0x1

    # errors 0x1000 - 0x100F
    HSA_STATUS_ERROR                            = 0x1000
    HSA_STATUS_ERROR_INVALID_ARGUMENT           = 0x1001
    HSA_STATUS_ERROR_INVALID_QUEUE_CREATION     = 0x1002
    HSA_STATUS_ERROR_INVALID_ALLOCATION         = 0x1003
    HSA_STATUS_ERROR_INVALID_AGENT              = 0x1004
    HSA_STATUS_ERROR_INVALID_REGION             = 0x1005
    HSA_STATUS_ERROR_INVALID_SIGNAL             = 0x1006
    HSA_STATUS_ERROR_INVALID_QUEUE              = 0x1007
    HSA_STATUS_ERROR_OUT_OF_RESOURCES           = 0x1008
    HSA_STATUS_ERROR_INVALID_PACKET_FORMAT      = 0x1009
    HSA_STATUS_ERROR_RESOURCE_FREE              = 0x100A
    HSA_STATUS_ERROR_NOT_INITIALIZED            = 0x100B
    HSA_STATUS_ERROR_REFCOUNT_OVERFLOW          = 0x100C
    HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS     = 0x100D
    HSA_STATUS_ERROR_INVALID_INDEX              = 0x100E
    HSA_STATUS_ERROR_INVALID_ISA                = 0x100F

    # errors 0x1010 - 0x1019
    HSA_STATUS_ERROR_INVALID_CODE_OBJECT        = 0x1010
    HSA_STATUS_ERROR_INVALID_EXECUTABLE         = 0x1011
    HSA_STATUS_ERROR_FROZEN_EXECUTABLE          = 0x1012
    HSA_STATUS_ERROR_INVALID_SYMBOL_NAME        = 0x1013
    HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED   = 0x1014
    HSA_STATUS_ERROR_VARIABLE_UNDEFINED         = 0x1015
    HSA_STATUS_ERROR_EXCEPTION                  = 0x1016
    HSA_STATUS_ERROR_INVALID_ISA_NAME           = 0x1017
    HSA_STATUS_ERROR_INVALID_CODE_SYMBOL        = 0x1018
    HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL  = 0x1019

    # errors 0x1020 - 0x1026
    HSA_STATUS_ERROR_INVALID_FILE               = 0x1020
    HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER = 0x1021
    HSA_STATUS_ERROR_INVALID_CACHE              = 0x1022
    HSA_STATUS_ERROR_INVALID_WAVEFRONT          = 0x1023
    HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP       = 0x1024
    HSA_STATUS_ERROR_INVALID_RUNTIME_STATE      = 0x1025
    HSA_STATUS_ERROR_FATAL                      = 0x1026
end
79+
80+
"""
    callback(agent::hsa_agent_t, data::Ptr{Vector{String}})

Per-agent callback used with `hsa_iterate_agents`. Query the name of `agent` and, when
the query succeeds, append it to the `Vector{String}` that `data` points to.

Return the status reported by `hsa_agent_get_info`.
"""
function callback(agent::hsa_agent_t, data::Ptr{Vector{String}})
    # recover the Julia vector behind the raw pointer handed through the C call
    collected = Base.unsafe_pointer_to_objref(data)

    # zero-filled 64-element buffer that receives the agent name
    buffer = fill(Cchar(0), 64)
    status = @ccall libhsa_runtime64.hsa_agent_get_info(
        agent::hsa_agent_t, HSA_AGENT_INFO_NAME::Cint, buffer::Ptr{Cchar}
    )::hsa_status_t

    # on failure nothing is recorded; the status is passed back unchanged
    status == HSA_STATUS_SUCCESS || return status

    # keep `buffer` alive while its raw pointer is read as a C string
    GC.@preserve buffer begin
        push!(collected, Base.unsafe_string(pointer(buffer)))
    end
    return status
end
89+
90+
"""
    agent_names()

Return the names of all agents reported by the HSA runtime, as a `Vector{String}`.

The runtime is initialised with `hsa_init`, the agents are visited with
`hsa_iterate_agents` (using [`callback`](@ref) to collect each name), and the runtime
is shut down again with `hsa_shut_down`.

Throws an `ErrorException` whose message is the failing `hsa_status_t` value when any
of the three calls does not return `HSA_STATUS_SUCCESS`.
"""
function agent_names()
    r = Ref(String[])
    cb = @cfunction(callback, hsa_status_t, (hsa_agent_t, Ptr{Vector{String}}))

    status = @ccall libhsa_runtime64.hsa_init()::hsa_status_t
    status != HSA_STATUS_SUCCESS && error(status)

    # `r` must stay rooted for as long as the raw pointer derived from it is in use
    iterate_status = GC.@preserve r begin
        ptr = Base.unsafe_convert(Ptr{Vector{String}}, r)
        @ccall libhsa_runtime64.hsa_iterate_agents(cb::Ptr{Cvoid}, ptr::Ptr{Vector{String}})::hsa_status_t
    end

    # Always pair the successful `hsa_init` with a `hsa_shut_down`, even when the
    # iteration failed, so a failed detection does not leave the runtime initialised.
    shutdown_status = @ccall libhsa_runtime64.hsa_shut_down()::hsa_status_t

    # report the iteration failure first: it is the earlier error
    iterate_status != HSA_STATUS_SUCCESS && error(iterate_status)
    shutdown_status != HSA_STATUS_SUCCESS && error(shutdown_status)
    return r[]
end
102+
103+
"""
    name_to_platform(name::String)

Map a GPU agent name such as `"gfx1030"` to the corresponding `rocm_platform` tag
value such as `"gfx103x_dgpu"`. Return `nothing` when the name matches none of the
known families.
"""
function name_to_platform(name::String)
    # (predicate, pattern, platform) triples, tried from top to bottom; most families
    # are matched by prefix, `gfx1150` and `gfx1151` only by their exact name
    rules = (
        (startswith, "gfx101",  "gfx101x_dgpu"),
        (startswith, "gfx103",  "gfx103x_dgpu"),
        (startswith, "gfx110",  "gfx110x_all"),
        ((==),       "gfx1150", "gfx1150"),
        ((==),       "gfx1151", "gfx1151"),
        (startswith, "gfx120",  "gfx120x_all"),
        (startswith, "gfx90",   "gfx90x_dcgpu"),
        (startswith, "gfx94",   "gfx94x_dcgpu"),
        (startswith, "gfx950",  "gfx950_dcgpu"),
    )
    for (matches, pattern, platform) in rules
        matches(name, pattern) && return platform
    end
    return nothing
end
126+
127+
"""
    detect_rocm_platform()

Return the `rocm_platform` tag value for the GPU agents found on this system, or the
string `"none"` when no agent could be queried, none has a `gfx` name, or none maps to
a known platform. When several distinct platforms are found, a warning is logged and
the first one is returned.
"""
function detect_rocm_platform()
    # a failure to query the runtime is downgraded to "no agents found"
    names = String[]
    try
        append!(names, agent_names())
    catch e
        @warn "Failed to detect ROCm platform: $e"
    end

    # only agents whose name starts with "gfx" are considered
    filter!(agent -> startswith(agent, "gfx"), names)
    if isempty(names)
        @warn "No ROCm GPU agents detected on this system."
        return "none"
    end

    # distinct platforms, in order of first appearance; unknown names are skipped
    platforms = String[]
    for agent in names
        candidate = name_to_platform(agent)
        (candidate === nothing || candidate in platforms) && continue
        push!(platforms, candidate)
    end

    if isempty(platforms)
        @warn "Unrecognized ROCm GPU agents detected on this system: $(join(names, ", "))."
        return "none"
    end
    if length(platforms) > 1
        @warn "Multiple supported ROCm platforms detected on this system: $(join(platforms, ", ")). Using the first one. Override by setting the `rocm_platform` preference."
    end
    return first(platforms)
end
151+
152+
"""
    augment_platform!(platform::Platform)

Add the ROCm-specific tags to `platform` and return it.

For Linux x86_64 platforms, `rocm_platform` is set from [`detect_rocm_platform`](@ref)
unless the tag is already present, and `rocm_local` is set to `"true"` or `"false"`
depending on the `local` preference. Any other platform is returned unchanged.
"""
function augment_platform!(platform::Platform)
    # Only augment Linux x86_64 platforms. The operating system is read from `platform`
    # itself (not from the running system), so both conditions describe the platform
    # that is being augmented.
    if Sys.islinux(platform) && arch(platform) == "x86_64"
        # an explicitly provided tag wins over detection
        if !haskey(platform, "rocm_platform")
            platform["rocm_platform"] = detect_rocm_platform()
        end

        # Store the fact that we're using a local ROCm installation; a missing
        # (unset or invalid) preference counts as "false"
        platform["rocm_local"] = string(local_preference !== missing && local_preference)
    end

    return platform
end

platforms/rocm.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
module ROCm
2+
3+
using Pkg
4+
5+
using BinaryBuilder
6+
7+
using Base.BinaryPlatforms
8+
using Base.BinaryPlatforms: arch, os, tags
9+
10+
# the "rocm_platform" platform tag contains the GPU architecture (e.g., "gfx103x_dgpu")
11+
# detected by querying the HSA runtime, and is used to select artifacts that depend on ROCm.
12+
13+
# Julia source of a platform augmentation hook, kept as a string.
# NOTE(review): presumably embedded by recipes that depend on rocm_sdk_core_jll as
# their `augment_platform_block` -- confirm against the callers.
#
# The code in the string:
#   * tries to load rocm_sdk_core_jll and tolerates failure;
#   * reads the `local` preference of rocm_sdk_core_jll, accepting either a boolean or
#     a string that parses as one (anything else counts as `false`);
#   * defines `rocm_comparison_strategy`, which never matches when a local toolkit is
#     requested and otherwise requires equal tag values;
#   * defines `augment_platform!`, which sets `rocm_platform` to "none" when
#     rocm_sdk_core_jll could not be loaded, defers to
#     `rocm_sdk_core_jll.augment_platform!` when the tag is absent, and installs the
#     comparison strategy for the `rocm_platform` tag.
const augment = """
    using Base.BinaryPlatforms

    try
        using rocm_sdk_core_jll
    catch
        # during initial package installation, rocm_sdk_core_jll may not be available.
        # in that case, we just won't select an artifact.
    end

    # can't use Preferences for the same reason
    const rocm_sdk_core_jll_uuid = Base.UUID("9ab9228b-5f62-5ec0-95ae-72487824505f")
    const preferences = Base.get_preferences(rocm_sdk_core_jll_uuid)
    Base.record_compiletime_preference(rocm_sdk_core_jll_uuid, "local")
    # the preference may be stored as a boolean or as a string such as "true";
    # anything that is neither counts as not using a local toolkit
    const local_toolkit = let value = get(preferences, "local", false)
        if value isa Bool
            value
        elseif value isa AbstractString
            something(tryparse(Bool, value), false)
        else
            false
        end
    end

    function rocm_comparison_strategy(a::String, b::String, a_requested::Bool, b_requested::Bool)
        # if we're using a local toolkit, we can't use artifacts
        if local_toolkit
            return false
        end
        return a == b
    end

    function augment_platform!(platform::Platform)
        if !@isdefined(rocm_sdk_core_jll)
            # don't set to nothing or Pkg will download any artifact
            platform["rocm_platform"] = "none"
        end

        if !haskey(platform, "rocm_platform")
            rocm_sdk_core_jll.augment_platform!(platform)
        end
        BinaryPlatforms.set_compare_strategy!(platform, "rocm_platform", rocm_comparison_strategy)

        return platform
    end"""
50+
51+
# Known ROCm GPU architectures, spelled the way they appear as `rocm_platform` tag
# values. This list is the default for `supported_platforms`.
const rocm_platforms = String[
    "gfx101x_dgpu",
    "gfx103x_dgpu",
    "gfx110x_all",
    # "gfx110x_dgpu" is currently not enabled
    "gfx1150",
    "gfx1151",
    "gfx120x_all",
    "gfx90x_dcgpu",
    "gfx94x_dcgpu",
    "gfx950_dcgpu",
]
64+
65+
"""
66+
supported_platforms(; platforms=rocm_platforms)
67+
68+
Return a list of supported platforms to build ROCm artifacts for.
69+
70+
# Arguments
71+
- `platforms=rocm_platforms`: List of ROCm GPU architectures to target.
72+
"""
73+
function supported_platforms(; platforms=rocm_platforms)
    # host configurations to build for; at the moment a single glibc Linux x86_64
    # platform using the cxx11 string ABI
    hosts = (
        Platform("x86_64", "linux"; libc = "glibc", cxxstring_abi = "cxx11"),
    )

    # one independent copy of every host platform per requested GPU architecture,
    # each carrying that architecture in its `rocm_platform` tag
    tagged = Platform[]
    for host in hosts, gpu in platforms
        target = deepcopy(host)
        target["rocm_platform"] = gpu
        push!(tagged, target)
    end

    return tagged
end
90+
91+
"""
92+
is_supported(platform)
93+
94+
Check if a platform is supported by ROCm, and whether we can build artifacts for it.
95+
"""
96+
function is_supported(platform)
    # anything that is not Linux is rejected before the architecture is looked at
    Sys.islinux(platform) || return false
    # among Linux platforms only x86_64 qualifies
    return arch(platform) == "x86_64"
end
99+
100+
"""
101+
required_dependencies(platform)
102+
103+
Return a list of dependencies required to build and use ROCm artifacts for a given platform.
104+
"""
105+
function required_dependencies(platform)
    deps = BinaryBuilder.AbstractDependency[]

    # a missing `rocm_platform` tag is treated exactly like the value "none":
    # no ROCm dependency is added in either case
    if get(tags(platform), "rocm_platform", "none") != "none"
        push!(deps, RuntimeDependency(PackageSpec(name="rocm_sdk_core_jll")))
    end

    return deps
end
114+
115+
end

0 commit comments

Comments
 (0)