program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})] { func main(tensor new_k_cache, tensor new_v_cache, tensor old_k_cache, tensor old_v_cache) { tensor var_6 = const()[name = tensor("op_6"), val = tensor(-3)]; tensor cat_k_1_interleave_0 = const()[name = tensor("cat_k_1_interleave_0"), val = tensor(false)]; tensor cat_k_1_cast_fp16 = concat(axis = var_6, interleave = cat_k_1_interleave_0, values = (old_k_cache, new_k_cache))[name = tensor("cat_k_1_cast_fp16")]; tensor var_9 = const()[name = tensor("op_9"), val = tensor(-1)]; tensor cat_v_interleave_0 = const()[name = tensor("cat_v_interleave_0"), val = tensor(false)]; tensor cat_v_cast_fp16 = concat(axis = var_9, interleave = cat_v_interleave_0, values = (old_v_cache, new_v_cache))[name = tensor("cat_v_cast_fp16")]; tensor var_20_begin_0 = const()[name = tensor("op_20_begin_0"), val = tensor([0, 64, 0, 0])]; tensor var_20_end_0 = const()[name = tensor("op_20_end_0"), val = tensor([1, 3072, 1, 1024])]; tensor var_20_end_mask_0 = const()[name = tensor("op_20_end_mask_0"), val = tensor([true, false, true, true])]; tensor updated_k_cache = slice_by_index(begin = var_20_begin_0, end = var_20_end_0, end_mask = var_20_end_mask_0, x = cat_k_1_cast_fp16)[name = tensor("op_20_cast_fp16")]; tensor var_50_begin_0 = const()[name = tensor("op_50_begin_0"), val = tensor([0, 0, 0, 64])]; tensor var_50_end_0 = const()[name = tensor("op_50_end_0"), val = tensor([1, 1024, 1, 3072])]; tensor var_50_end_mask_0 = const()[name = tensor("op_50_end_mask_0"), val = tensor([true, true, true, false])]; tensor updated_v_cache = slice_by_index(begin = var_50_begin_0, end = var_50_end_0, end_mask = var_50_end_mask_0, x = cat_v_cast_fp16)[name = tensor("op_50_cast_fp16")]; tensor var_51_promoted_to_fp16 = const()[name = tensor("op_51_promoted_to_fp16"), val = tensor(0x1p+1)]; tensor prod_cast_fp16 = mul(x = updated_k_cache, y = var_51_promoted_to_fp16)[name = tensor("prod_cast_fp16")]; tensor var_53_keep_dims_0 = const()[name = tensor("op_53_keep_dims_0"), val = tensor(false)]; tensor ignore_me_im_only_here_so_this_runs_on_the_ane = reduce_min(keep_dims = var_53_keep_dims_0, x = prod_cast_fp16)[name = tensor("op_53_cast_fp16")]; } -> (updated_k_cache, updated_v_cache, ignore_me_im_only_here_so_this_runs_on_the_ane); }