From edfef5d9156ac993ddb8a7d13b8363af7bb3c44e Mon Sep 17 00:00:00 2001 From: Faith Ekstrand Date: Wed, 8 Feb 2023 16:17:58 -0600 Subject: [PATCH] vulkan: Parse the platform in Extensions.from_xml() This makes handling guards on entrypoints a bit easier. Acked-By: Mike Blumenkrantz Part-of: --- src/vulkan/util/vk_entrypoints.py | 42 ++++++++++----------------- src/vulkan/util/vk_entrypoints_gen.py | 20 ------------- src/vulkan/util/vk_extensions.py | 2 ++ 3 files changed, 18 insertions(+), 46 deletions(-) diff --git a/src/vulkan/util/vk_entrypoints.py b/src/vulkan/util/vk_entrypoints.py index 8bd0e4767cd..acf9ef2c299 100644 --- a/src/vulkan/util/vk_entrypoints.py +++ b/src/vulkan/util/vk_entrypoints.py @@ -45,11 +45,11 @@ class EntrypointBase: return prefix + '_' + self.name class Entrypoint(EntrypointBase): - def __init__(self, name, return_type, params, guard=None): + def __init__(self, name, return_type, params): super(Entrypoint, self).__init__(name) self.return_type = return_type self.params = params - self.guard = guard + self.guard = None self.aliases = [] self.disp_table_index = None @@ -98,7 +98,7 @@ class EntrypointAlias(EntrypointBase): def call_params(self): return self.alias.call_params() -def get_entrypoints(doc, entrypoints_to_defines): +def get_entrypoints(doc, platform_guards): """Extract the entry points from the registry.""" entrypoints = OrderedDict() @@ -116,10 +116,9 @@ def get_entrypoints(doc, entrypoints_to_defines): decl=''.join(p.itertext()), len=p.attrib.get('altlen', p.attrib.get('len', None)) ) for p in command.findall('./param')] - guard = entrypoints_to_defines.get(name) # They really need to be unique assert name not in entrypoints - entrypoints[name] = Entrypoint(name, ret_type, params, guard) + entrypoints[name] = Entrypoint(name, ret_type, params) for feature in doc.findall('./feature'): assert feature.attrib['api'] == 'vulkan' @@ -130,47 +129,38 @@ def get_entrypoints(doc, entrypoints_to_defines): e.core_version = version for extension in doc.findall('.extensions/extension'): - if extension.attrib['supported'] != 'vulkan': + ext = Extension.from_xml(extension) + if 'vulkan' not in ext.supported: continue - ext_name = extension.attrib['name'] - - ext = Extension(ext_name, 1) - ext.type = extension.attrib['type'] - for command in extension.findall('./require/command'): e = entrypoints[command.attrib['name']] assert e.core_version is None e.extensions.append(ext) + if ext.platform in platform_guards: + guard = platform_guards[ext.platform] + if e.guard is None: + e.guard = guard + else: + assert e.guard == guard return entrypoints.values() - -def get_entrypoints_defines(doc): - """Maps entry points to extension defines.""" - entrypoints_to_defines = {} - +def get_platform_defines(doc): platform_define = {} for platform in doc.findall('./platforms/platform'): name = platform.attrib['name'] define = platform.attrib['protect'] platform_define[name] = define - for extension in doc.findall('./extensions/extension[@platform]'): - platform = extension.attrib['platform'] - define = platform_define[platform] - - for entrypoint in extension.findall('./require/command'): - fullname = entrypoint.attrib['name'] - entrypoints_to_defines[fullname] = define - - return entrypoints_to_defines + return platform_define def get_entrypoints_from_xml(xml_files): entrypoints = [] for filename in xml_files: doc = et.parse(filename) - entrypoints += get_entrypoints(doc, get_entrypoints_defines(doc)) + guards = get_platform_defines(doc) + entrypoints += get_entrypoints(doc, guards) return entrypoints diff --git a/src/vulkan/util/vk_entrypoints_gen.py b/src/vulkan/util/vk_entrypoints_gen.py index 67dfdd43bcd..6763f1c8ad9 100644 --- a/src/vulkan/util/vk_entrypoints_gen.py +++ b/src/vulkan/util/vk_entrypoints_gen.py @@ -164,26 +164,6 @@ ${entrypoint_table('physical_device', physical_device_entrypoints, physical_devi ${entrypoint_table('device', device_entrypoints, device_prefixes)} """) -def get_entrypoints_defines(doc): - """Maps entry points to extension defines.""" - entrypoints_to_defines = {} - - platform_define = {} - for platform in doc.findall('./platforms/platform'): - name = platform.attrib['name'] - define = platform.attrib['protect'] - platform_define[name] = define - - for extension in doc.findall('./extensions/extension[@platform]'): - platform = extension.attrib['platform'] - define = platform_define[platform] - - for entrypoint in extension.findall('./require/command'): - fullname = entrypoint.attrib['name'] - entrypoints_to_defines[fullname] = define - - return entrypoints_to_defines - def main(): parser = argparse.ArgumentParser() diff --git a/src/vulkan/util/vk_extensions.py b/src/vulkan/util/vk_extensions.py index 194f6adaf2c..1d8db3a051c 100644 --- a/src/vulkan/util/vk_extensions.py +++ b/src/vulkan/util/vk_extensions.py @@ -15,6 +15,7 @@ class Extension: def __init__(self, name, ext_version): self.name = name self.type = None + self.platform = None self.ext_version = int(ext_version) self.supported = [] @@ -39,6 +40,7 @@ class Extension: assert version is not None ext = Extension(name, version) ext.type = ext_elem.attrib['type'] + ext.platform = ext_elem.attrib.get('platform', None) ext.supported = supported return ext