# -- TVMScript module printout (line-number gutter from the original capture removed) --
# TVMScript printout of a TVM IRModule: an end-to-end CNN classifier for a
# batch of 4 single-channel 28x28 images (MNIST-shaped input).
# Pipeline (see `main`): conv2d+bias -> relu -> 2x2 maxpool -> flatten
# -> linear+bias -> relu -> linear+bias -> softmax -> (4, 10) probabilities.
# NOTE(review): the names `I`, `T`, `R` and the `metadata` constant table are
# supplied by tvm.script / the surrounding printout, which is outside this view.
@I.ir_module
class Module:
    # 3x3 convolution, stride 1, no padding, plus per-output-channel bias.
    #   x (input):   (4, 1, 28, 28)
    #   B (weights): (32, 1, 3, 3)   -- 32 output channels, 1 input channel
    #   C (bias):    (1, 32, 1, 1)
    #   compute:     (4, 32, 26, 26) -- 28 - 3 + 1 = 26 ("valid" convolution)
    @T.prim_func(private=True)
    def my_conv2d(x: T.Buffer((T.int64(4), T.int64(1), T.int64(28), T.int64(28)), "float32"), B: T.Buffer((T.int64(32), T.int64(1), T.int64(3), T.int64(3)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(1)), "float32"), compute: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Scratch accumulator: convolution result before the bias is added.
        conv2d = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)))
        # n=batch, co=output channel, (oh, ow)=output pixel, k=input channel,
        # (r, s)=kernel offsets.  "SSSSRRR" marks k/r/s as reduction axes.
        for n, co, oh, ow, k, r, s in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26), T.int64(1), T.int64(3), T.int64(3)):
            with T.block("conv2d"):
                v_n, v_co, v_oh, v_ow, v_k, v_r, v_s = T.axis.remap("SSSSRRR", [n, co, oh, ow, k, r, s])
                T.reads(x[v_n, v_k, v_oh + v_r, v_ow + v_s], B[v_co, v_k, v_r, v_s])
                T.writes(conv2d[v_n, v_co, v_oh, v_ow])
                with T.init():
                    conv2d[v_n, v_co, v_oh, v_ow] = T.float32(0.0)
                conv2d[v_n, v_co, v_oh, v_ow] = conv2d[v_n, v_co, v_oh, v_ow] + x[v_n, v_k, v_oh + v_r, v_ow + v_s] * B[v_co, v_k, v_r, v_s]
        # Second pass: broadcast-add the per-channel bias into the output.
        for n, co, oh, ow in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26)):
            with T.block("compute"):
                v_n, v_co, v_oh, v_ow = T.axis.remap("SSSS", [n, co, oh, ow])
                T.reads(conv2d[v_n, v_co, v_oh, v_ow], C[T.int64(0), v_co, T.int64(0), T.int64(0)])
                T.writes(compute[v_n, v_co, v_oh, v_ow])
                compute[v_n, v_co, v_oh, v_ow] = conv2d[v_n, v_co, v_oh, v_ow] + C[T.int64(0), v_co, T.int64(0), T.int64(0)]

    # Row-major flatten: (4, 32, 13, 13) -> (4, 5408); 5408 = 32 * 13 * 13.
    # Flat index i decomposes as c = i // 169, h = (i % 169) // 13, w = i % 13
    # (169 = 13*13, so (i % 169) % 13 == i % 13 and the last index is valid).
    @T.prim_func(private=True)
    def my_flatten(lv2: T.Buffer((T.int64(4), T.int64(32), T.int64(13), T.int64(13)), "float32"), compute: T.Buffer((T.int64(4), T.int64(5408)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for n, i in T.grid(T.int64(4), T.int64(5408)):
            with T.block("compute"):
                v_n, v_i = T.axis.remap("SS", [n, i])
                T.reads(lv2[v_n, v_i // T.int64(169), v_i % T.int64(169) // T.int64(13), v_i % T.int64(13)])
                T.writes(compute[v_n, v_i])
                compute[v_n, v_i] = lv2[v_n, v_i // T.int64(169), v_i % T.int64(169) // T.int64(13), v_i % T.int64(13)]

    # Fully-connected layer: compute = lv3 @ B^T + C.
    #   lv3: (4, 5408), B (weights): (100, 5408), C (bias): (1, 100),
    #   compute: (4, 100)
    @T.prim_func(private=True)
    def my_linear(lv3: T.Buffer((T.int64(4), T.int64(5408)), "float32"), B: T.Buffer((T.int64(100), T.int64(5408)), "float32"), C: T.Buffer((T.int64(1), T.int64(100)), "float32"), compute: T.Buffer((T.int64(4), T.int64(100)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Matmul accumulator; the bias is added in a second pass below.
        compute_1 = T.alloc_buffer((T.int64(4), T.int64(100)))
        # i=batch row, j=output feature, FI=reduction over input features.
        for i, j, FI in T.grid(T.int64(4), T.int64(100), T.int64(5408)):
            with T.block("compute"):
                v_i, v_j, v_FI = T.axis.remap("SSR", [i, j, FI])
                T.reads(lv3[v_i, v_FI], B[v_j, v_FI])
                T.writes(compute_1[v_i, v_j])
                with T.init():
                    compute_1[v_i, v_j] = T.float32(0.0)
                compute_1[v_i, v_j] = compute_1[v_i, v_j] + lv3[v_i, v_FI] * B[v_j, v_FI]
        # Broadcast-add the bias row over the batch dimension.
        for i, j in T.grid(T.int64(4), T.int64(100)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(C[T.int64(0), v_j], compute_1[v_i, v_j])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = C[T.int64(0), v_j] + compute_1[v_i, v_j]

    # Second fully-connected layer, same scheme as `my_linear`:
    # compute = lv5 @ B^T + C, with lv5: (4, 100), B: (10, 100),
    # C: (1, 10), compute: (4, 10) -- one score per class.
    @T.prim_func(private=True)
    def my_linear1(lv5: T.Buffer((T.int64(4), T.int64(100)), "float32"), B: T.Buffer((T.int64(10), T.int64(100)), "float32"), C: T.Buffer((T.int64(1), T.int64(10)), "float32"), compute: T.Buffer((T.int64(4), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Matmul accumulator; bias added in the second pass.
        compute_1 = T.alloc_buffer((T.int64(4), T.int64(10)))
        for i, j, FI in T.grid(T.int64(4), T.int64(10), T.int64(100)):
            with T.block("compute"):
                v_i, v_j, v_FI = T.axis.remap("SSR", [i, j, FI])
                T.reads(lv5[v_i, v_FI], B[v_j, v_FI])
                T.writes(compute_1[v_i, v_j])
                with T.init():
                    compute_1[v_i, v_j] = T.float32(0.0)
                compute_1[v_i, v_j] = compute_1[v_i, v_j] + lv5[v_i, v_FI] * B[v_j, v_FI]
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(C[T.int64(0), v_j], compute_1[v_i, v_j])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = C[T.int64(0), v_j] + compute_1[v_i, v_j]

    # 2x2 max pooling, stride 2: (4, 32, 26, 26) -> (4, 32, 13, 13).
    # (i, j) range over the 2x2 window as reduction axes ("SSSSRR"); the
    # accumulator is initialized to the float32 lowest value (-FLT_MAX).
    @T.prim_func(private=True)
    def my_maxpool2d(lv1: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32"), maxpool2d: T.Buffer((T.int64(4), T.int64(32), T.int64(13), T.int64(13)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for n, co, oh, ow, i, j in T.grid(T.int64(4), T.int64(32), T.int64(13), T.int64(13), T.int64(2), T.int64(2)):
            with T.block("maxpool2d"):
                v_n, v_co, v_oh, v_ow, v_i, v_j = T.axis.remap("SSSSRR", [n, co, oh, ow, i, j])
                T.reads(lv1[v_n, v_co, v_oh * T.int64(2) + v_i, v_ow * T.int64(2) + v_j])
                T.writes(maxpool2d[v_n, v_co, v_oh, v_ow])
                with T.init():
                    maxpool2d[v_n, v_co, v_oh, v_ow] = T.float32(-340282346638528859811704183484516925440.0)
                maxpool2d[v_n, v_co, v_oh, v_ow] = T.max(maxpool2d[v_n, v_co, v_oh, v_ow], lv1[v_n, v_co, v_oh * T.int64(2) + v_i, v_ow * T.int64(2) + v_j])

    # Elementwise ReLU over the conv feature map: max(x, 0) on (4, 32, 26, 26).
    @T.prim_func(private=True)
    def my_relu(lv: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32"), compute: T.Buffer((T.int64(4), T.int64(32), T.int64(26), T.int64(26)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(32), T.int64(26), T.int64(26)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute[v_i0, v_i1, v_i2, v_i3])
                compute[v_i0, v_i1, v_i2, v_i3] = T.max(lv[v_i0, v_i1, v_i2, v_i3], T.float32(0.0))

    # Elementwise ReLU over the first linear layer's output: (4, 100).
    @T.prim_func(private=True)
    def my_relu1(lv4: T.Buffer((T.int64(4), T.int64(100)), "float32"), compute: T.Buffer((T.int64(4), T.int64(100)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1 in T.grid(T.int64(4), T.int64(100)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lv4[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(lv4[v_i0, v_i1], T.float32(0.0))

    # Numerically stable softmax along axis 1 of a (4, 10) matrix, in four
    # passes: row max -> exp(x - max) -> row sum -> divide.
    @T.prim_func(private=True)
    def my_softmax(lv6: T.Buffer((T.int64(4), T.int64(10)), "float32"), compute: T.Buffer((T.int64(4), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        compute_1 = T.alloc_buffer((T.int64(4),))               # per-row max
        compute_2 = T.alloc_buffer((T.int64(4), T.int64(10)))   # exp(x - max)
        compute_3 = T.alloc_buffer((T.int64(4),))               # per-row sum of exps
        # Pass 1: row-wise max, accumulator initialized to -FLT_MAX.
        for i, c in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute"):
                v_i, v_c = T.axis.remap("SR", [i, c])
                T.reads(lv6[v_i, v_c])
                T.writes(compute_1[v_i])
                with T.init():
                    compute_1[v_i] = T.float32(-340282346638528859811704183484516925440.0)
                compute_1[v_i] = T.max(compute_1[v_i], lv6[v_i, v_c])
        # Pass 2: shift by the row max before exponentiating (avoids overflow).
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_1"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv6[v_i, v_j], compute_1[v_i])
                T.writes(compute_2[v_i, v_j])
                compute_2[v_i, v_j] = T.exp(lv6[v_i, v_j] - compute_1[v_i])
        # Pass 3: row-wise sum of the exponentials.
        for i, c in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_2"):
                v_i, v_c = T.axis.remap("SR", [i, c])
                T.reads(compute_2[v_i, v_c])
                T.writes(compute_3[v_i])
                with T.init():
                    compute_3[v_i] = T.float32(0.0)
                compute_3[v_i] = compute_3[v_i] + compute_2[v_i, v_c]
        # Pass 4: normalize so each row sums to 1.
        for i, j in T.grid(T.int64(4), T.int64(10)):
            with T.block("compute_3"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(compute_2[v_i, v_j], compute_3[v_i])
                T.writes(compute[v_i, v_j])
                compute[v_i, v_j] = compute_2[v_i, v_j] / compute_3[v_i]

    # Relax graph function wiring the prim_funcs into the full forward pass.
    # The metadata["relax.expr.Constant"] entries are constants embedded
    # elsewhere in the printout -- presumably the trained parameters
    # (conv weight/bias, fc1 weight/bias, fc2 weight/bias); not visible here.
    @R.function
    def main(x: R.Tensor((4, 1, 28, 28), dtype="float32")) -> R.Tensor((4, 10), dtype="float32"):
        cls = Module  # alias used to reference this module's prim_funcs
        with R.dataflow():
            lv = R.call_tir(cls.my_conv2d, (x, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((4, 32, 26, 26), dtype="float32"))
            lv1 = R.call_tir(cls.my_relu, (lv,), out_sinfo=R.Tensor((4, 32, 26, 26), dtype="float32"))
            lv2 = R.call_tir(cls.my_maxpool2d, (lv1,), out_sinfo=R.Tensor((4, 32, 13, 13), dtype="float32"))
            lv3 = R.call_tir(cls.my_flatten, (lv2,), out_sinfo=R.Tensor((4, 5408), dtype="float32"))
            lv4 = R.call_tir(cls.my_linear, (lv3, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((4, 100), dtype="float32"))
            lv5 = R.call_tir(cls.my_relu1, (lv4,), out_sinfo=R.Tensor((4, 100), dtype="float32"))
            lv6 = R.call_tir(cls.my_linear1, (lv5, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5]), out_sinfo=R.Tensor((4, 10), dtype="float32"))
            lv7 = R.call_tir(cls.my_softmax, (lv6,), out_sinfo=R.Tensor((4, 10), dtype="float32"))
            gv: R.Tensor((4, 10), dtype="float32") = lv7
            R.output(gv)
        return gv