diff --git a/python/perf-kernels/tools/rocm-triton-prof/README.md b/python/perf-kernels/tools/rocm-triton-prof/README.md index 350c0aa0ce00..04b9bf5c2cfb 100644 --- a/python/perf-kernels/tools/rocm-triton-prof/README.md +++ b/python/perf-kernels/tools/rocm-triton-prof/README.md @@ -129,7 +129,20 @@ max 261.981219 dtype: float64 ``` +You can also run the tool using a yaml configuration file as well. For example: + +```bash +$ cat ./config.yaml +kernel: ".*fwd" +cmd: 'python3 ./flash-attention.py -b 2 -hq 4 -hk 4 -sq 113 -sk 123 -d 1 -causal -layout bhsd' + +$ python3 ./rocm-triton-prof.py -f ./config.yaml +``` + +Note, you can use a regular expression for the kernel name (see example above). The tool is +going to abort of more than one kernel match the given regular expression. + ### Known limits -The tool currently supports only FP64, FP32 and FP16 operations. -Note, it can be extended to supoprt other data types. +The tool currently supports only FP64, FP32 and FP16 operations. Note, it can be extended +to supoprt other data types. diff --git a/python/perf-kernels/tools/rocm-triton-prof/rocm-triton-prof.py b/python/perf-kernels/tools/rocm-triton-prof/rocm-triton-prof.py index feafb74db311..d375fa1309e2 100755 --- a/python/perf-kernels/tools/rocm-triton-prof/rocm-triton-prof.py +++ b/python/perf-kernels/tools/rocm-triton-prof/rocm-triton-prof.py @@ -104,9 +104,22 @@ def filter(df, name): return df[df['Kernel_Name'] == name] -def process_files(metrics_dir, timing_dir, kernel_name, verbose): +def find_kernel_name(df, kernel_expr, verbose): + names = df['Kernel_Name'] + raw_kernel_expr = r'{}'.format(kernel_expr) + names_set = set(names[names.str.match(raw_kernel_expr)]) + if len(names_set) != 1: + raise RuntimeError(f'Error: found several kernels matching `kernel` regex i.e., {names_set}') + kernel_name = next(iter(names_set)) + if verbose: + print(f'\nTracking: `{kernel_name}` kernel\n') + return next(iter(names_set)) + + +def process_files(metrics_dir, timing_dir, kernel_expr, verbose): timing_file = find_file(timing_dir, re.compile(r'.*kernel_trace.csv')) df = pd.read_csv(timing_file) + kernel_name = find_kernel_name(df, kernel_expr, verbose) df = filter(df, kernel_name) timing = df['End_Timestamp'] - df['Start_Timestamp'] print('Timing info in `nsec`:') @@ -247,10 +260,33 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-k", "--kernel", type=str, required=True, help="name of a kernel") - parser.add_argument('-c', '--cmd', required=True, nargs=argparse.REMAINDER, help='user command') + parser.add_argument("-f", "--file", type=str, required=False, help="config file") + parser.add_argument("-k", "--kernel", type=str, required=False, help="name of a kernel") + parser.add_argument('-c', '--cmd', required=False, nargs=argparse.REMAINDER, help='user command') parser.add_argument("--display-only", action='store_true', help='display info without running') parser.add_argument("-v", "--verbose", action='store_true', help='verbose output') args = parser.parse_args() - + if args.file: + if os.path.exists(args.file): + RuntimeError(f'file `{args.file}` does not exist') + with open(args.file, 'r') as stream: + data = yaml.safe_load(stream) + args.kernel = data['kernel'] + + args.cmd = data['cmd'] + assert (type(args.cmd) is str) + args.cmd = args.cmd.split(' ') + + args.display_only = args.display_only + if 'display_only' in data: + args.display_only = data['display_only'] + + args.verbose = args.verbose + if 'verbose' in data: + args.verbose = data['verbose'] + else: + if not hasattr(args, 'kernel'): + RuntimeError('kernel expr must be specified (--kernel)') + if not hasattr(args, 'cmd'): + RuntimeError('program command must be specified (--cmd)') main(args)