Skip to content

Commit 219132e

Browse files
q10facebook-github-bot
authored andcommitted
Migrate make_pta_acc_format() away from old macros, v2]
Summary: X-link: facebookresearch/FBGEMM#1111 - Migrate Jinja `make_pta_acc_format()` from the old `MAKE_PTA_WITH_NAME` and `MAKE_PTA_ACC_WITH_NAME` to using `PTA_B` and `PTA_ACC_B` Reviewed By: sryap Differential Revision: D73417820 fbshipit-source-id: 8dd700bacb389b6a7443f380892b453a27cea3f6
1 parent 0f00a8a commit 219132e

File tree

4 files changed

+207
-197
lines changed

4 files changed

+207
-197
lines changed

fbgemm_gpu/codegen/genscript/jinja_environment.py

+25
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,30 @@ def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]:
388388
return new_str_list
389389

390390

391+
def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]:
392+
new_str_list = []
393+
for pta_str in pta_str_list:
394+
if "packed_accessor" in pta_str:
395+
match = re.search(
396+
r"([a-zA-z0-9_]*)[.]packed_accessor([3|6][2|4])<(.*)>\(\)", pta_str
397+
)
398+
assert match is not None and len(match.groups()) == 3
399+
tensor, acc_nbits, args = match.groups()
400+
if "acc_type" in args:
401+
match = re.search("at::acc_type<([a-zA-Z_0-9]*), true>", args)
402+
assert match is not None and len(match.groups()) == 1
403+
new_type = match.group(1)
404+
args = re.sub("at::acc_type<[a-zA-Z_]*, true>", new_type, args)
405+
macro_name = "PTA_ACC_B"
406+
else:
407+
macro_name = "PTA_B"
408+
args = args.replace(", at::RestrictPtrTraits", "")
409+
new_str_list.append(f"{macro_name}({tensor}, {args}, {acc_nbits})")
410+
else:
411+
new_str_list.append(pta_str)
412+
return new_str_list
413+
414+
391415
def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:
392416
return [
393417
pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor")
@@ -431,6 +455,7 @@ def to_upper_placeholder_types(arg_str_list: List[str]) -> List[str]:
431455
################################################################################
432456

433457
env.filters["make_pta_acc_format"] = make_pta_acc_format
458+
env.filters["make_pta_acc_builder_format"] = make_pta_acc_builder_format
434459
env.filters["replace_pta_namespace"] = replace_pta_namespace
435460
env.filters["replace_placeholder_types"] = replace_placeholder_types
436461
env.filters["to_upper_placeholder_types"] = to_upper_placeholder_types

0 commit comments

Comments
 (0)