summaryrefslogtreecommitdiff
path: root/pkgs/development/rocm-modules/6/tensile/tensile-solutionstructs-perf-fix.diff
blob: 7157238042ece876bd37eec212a42a1df7114f56 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
diff --git a/Tensile/SolutionStructs.py b/Tensile/SolutionStructs.py
index f663c6f1..17bcf897 100644
--- a/Tensile/SolutionStructs.py
+++ b/Tensile/SolutionStructs.py
@@ -4828,24 +4828,26 @@ class Solution(collections.abc.Mapping):
   # create a dictionary of lists of parameter values
   @staticmethod
   def getSerialNaming(objs):
+    valid_params = sorted(validParameters.keys())
     data = {}
-    for objIdx in range(0, len(objs)):
-      obj = objs[objIdx]
-      for paramName in sorted(obj.keys()):
-        if paramName in list(validParameters.keys()):
-          paramValue = obj[paramName]
-          if paramName in data:
-            if paramValue not in data[paramName]:
-              data[paramName].append(paramValue)
-          else:
-            data[paramName] = [ paramValue ]
-    maxObjs = 1
-    for paramName in data:
-      if not isinstance(data[paramName][0],dict):
-        data[paramName] = sorted(data[paramName])
-      maxObjs *= len(data[paramName])
-    numDigits = len(str(maxObjs))
-    return [ data, numDigits ]
+
+    objs = [getattr(obj, "_state", obj) for obj in objs]
+
+    for param in valid_params:
+      d = []
+      for obj in objs:
+        if param in obj:
+          v = obj[param]
+          if v not in d:
+            d.append(v)
+      if len(d):
+        if not isinstance(d[0], dict): d.sort()
+        data[param] = d
+
+    # Calculate max objects using prod() from math module
+    max_objs = math.prod(len(values) for values in data.values())
+    num_digits = len(str(max_objs))
+    return data, num_digits
 
   ########################################
   # Get Name Serial