@masahi · Created April 11, 2020 08:20
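This is a Relay IR dump. From the signature it appears to be a three-layer bidirectional LSTM with layer normalization (sequence length 5, batch 2, input size 3, hidden size 4): each layer and direction is lowered to a recursive `%while_loop` over the time axis, and the per-layer states arrive as a Relay `List` of `(h, c)` tuples. The sketch below is only an assumption about how IR like this is typically produced with TVM's PyTorch frontend; the actual model and script behind the dump are not part of the gist, so the model name is a hypothetical placeholder and exact API spellings may differ between TVM versions.

```python
import torch
import tvm
from tvm import relay

# Hypothetical: a TorchScript layer-norm LSTM (e.g. a custom LSTM cell scripted
# with torch.jit). The real module behind this dump is not included here.
scripted = torch.jit.script(my_layernorm_bilstm)  # placeholder name

# from_pytorch takes the TorchScript module plus (name, shape) pairs for the
# graph inputs; "input" / (5, 2, 3) mirror %input in the dump below.
mod, params = relay.frontend.from_pytorch(scripted, [("input", (5, 2, 3))])

# The dump uses Relay lists (Nil/Cons) and a dynamic first dimension (?), so it
# runs on the Relay VM rather than the static graph executor.
vm_exec = relay.vm.compile(mod, target="llvm", params=params)
```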
fn (%v54: Tensor[(16, 3), float32], %v60: Tensor[(16), float32], %v61: Tensor[(16), float32], %v66: Tensor[(16, 4), float32], %v72: Tensor[(16), float32], %v73: Tensor[(16), float32], %v94: Tensor[(4), float32], %v95: Tensor[(4), float32], %states: List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]], %input: Tensor[(5, 2, 3), float32], %v135: Tensor[(16, 3), float32], %v141: Tensor[(16), float32], %v142: Tensor[(16), float32], %v147: Tensor[(16, 4), float32], %v153: Tensor[(16), float32], %v154: Tensor[(16), float32], %v175: Tensor[(4), float32], %v176: Tensor[(4), float32], %v222: Tensor[(16, 4), float32], %v228: Tensor[(16), float32], %v229: Tensor[(16), float32], %v234: Tensor[(16, 4), float32], %v240: Tensor[(16), float32], %v241: Tensor[(16), float32], %v262: Tensor[(4), float32], %v263: Tensor[(4), float32], %v303: Tensor[(16, 4), float32], %v309: Tensor[(16), float32], %v310: Tensor[(16), float32], %v315: Tensor[(16, 4), float32], %v321: Tensor[(16), float32], %v322: Tensor[(16), float32], %v343: Tensor[(4), float32], %v344: Tensor[(4), float32], %v390: Tensor[(16, 4), float32], %v396: Tensor[(16), float32], %v397: Tensor[(16), float32], %v402: Tensor[(16, 4), float32], %v408: Tensor[(16), float32], %v409: Tensor[(16), float32], %v430: Tensor[(4), float32], %v431: Tensor[(4), float32], %v471: Tensor[(16, 4), float32], %v477: Tensor[(16), float32], %v478: Tensor[(16), float32], %v483: Tensor[(16, 4), float32], %v489: Tensor[(16), float32], %v490: Tensor[(16), float32], %v511: Tensor[(4), float32], %v512: Tensor[(4), float32]) -> (Tensor[(?, 2, 4), float32], List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]]) {
%0 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%1 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%2 = @nth(%states, 2 /* ty=int32 */) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%3 = @nth(%2, 0 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%4 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%5 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%6 = @nth(%states, 1 /* ty=int32 */) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%7 = @nth(%6, 0 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%8 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%9 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%10 = @nth(%states, 0 /* ty=int32 */) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%11 = @nth(%10, 0 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%47 = (
let %while_loop: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) = fn (%i.10: int32, %outputs.14: List[Tensor[(2, 4), float32]], %state.13: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %input.1: Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) {
%12 = less(%i.10, 5 /* ty=int32 */) /* ty=bool */;
if (%12) {
%13 = add(%i.10, 1 /* ty=int32 */) /* ty=int32 */;
%14 = take(%input.1, %i.10, axis=0) /* ty=Tensor[(2, 3), float32] */;
%15 = transpose(%v54, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
%16 = transpose(%15, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
%17 = nn.dense(%14, %16, units=None) /* ty=Tensor[(2, 16), float32] */;
%18 = nn.layer_norm(%17, %v60, %v61) /* ty=Tensor[(2, 16), float32] */;
%19 = %state.13.0;
%20 = transpose(%v66, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%21 = transpose(%20, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%22 = nn.dense(%19, %21, units=None) /* ty=Tensor[(2, 16), float32] */;
%23 = nn.layer_norm(%22, %v72, %v73) /* ty=Tensor[(2, 16), float32] */;
%24 = add(%18, %23) /* ty=Tensor[(2, 16), float32] */;
%25 = strided_slice(%24, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%26 = sigmoid(%25) /* ty=Tensor[(2, 4), float32] */;
%27 = strided_slice(%24, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%28 = sigmoid(%27) /* ty=Tensor[(2, 4), float32] */;
%29 = %state.13.1;
%30 = multiply(%28, %29) /* ty=Tensor[(2, 4), float32] */;
%31 = strided_slice(%24, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%32 = sigmoid(%31) /* ty=Tensor[(2, 4), float32] */;
%33 = strided_slice(%24, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%34 = tanh(%33) /* ty=Tensor[(2, 4), float32] */;
%35 = multiply(%32, %34) /* ty=Tensor[(2, 4), float32] */;
%36 = add(%30, %35) /* ty=Tensor[(2, 4), float32] */;
%37 = nn.layer_norm(%36, %v94, %v95) /* ty=Tensor[(2, 4), float32] */;
%38 = tanh(%37) /* ty=Tensor[(2, 4), float32] */;
%39 = multiply(%26, %38) /* ty=Tensor[(2, 4), float32] */;
%40 = (%39, %37);
%41 = (%39, %40);
%42 = %41.0;
%43 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%44 = Cons(%42, %43) /* ty=List[Tensor[(2, 4), float32]] */;
%45 = @concat(%outputs.14, %44) /* ty=List[Tensor[(2, 4), float32]] */;
%46 = %41.1;
%while_loop(%13, %45, %46, %input.1) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */
} else {
(%i.10, %outputs.14, %state.13, %input.1)
}
};
%while_loop
);
%48 = %47(0 /* ty=int32 */, %9, %11, %input) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */;
%49 = %48.1;
%50 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %49) /* ty=List[static_tensor_float32_2_4_t[]] */;
%51 = @tensor_array_stack_float32_2_4(%50) /* ty=static_tensor_float32_?_2_4_t[] */;
%52 = @tensor_get_data_float32_2_4(%51) /* ty=Tensor[(?, 2, 4), float32] */;
%53 = %48.2;
%54 = (%52, %53);
%55 = %54.0;
%56 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%57 = Cons(%55, %56) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%58 = @concat(%8, %57) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%59 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%60 = @nth(%10, 1 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%98 = (
let %while_loop1: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) = fn (%i.11: int32, %outputs.19: List[Tensor[(2, 4), float32]], %state.17: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %input.11: Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) {
%61 = less(%i.11, 5 /* ty=int32 */) /* ty=bool */;
if (%61) {
%62 = add(%i.11, 1 /* ty=int32 */) /* ty=int32 */;
%63 = subtract(5 /* ty=int32 */, %i.11) /* ty=int32 */;
%64 = subtract(%63, 1 /* ty=int32 */) /* ty=int32 */;
%65 = take(%input.11, %64, axis=0) /* ty=Tensor[(2, 3), float32] */;
%66 = transpose(%v135, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
%67 = transpose(%66, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
%68 = nn.dense(%65, %67, units=None) /* ty=Tensor[(2, 16), float32] */;
%69 = nn.layer_norm(%68, %v141, %v142) /* ty=Tensor[(2, 16), float32] */;
%70 = %state.17.0;
%71 = transpose(%v147, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%72 = transpose(%71, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%73 = nn.dense(%70, %72, units=None) /* ty=Tensor[(2, 16), float32] */;
%74 = nn.layer_norm(%73, %v153, %v154) /* ty=Tensor[(2, 16), float32] */;
%75 = add(%69, %74) /* ty=Tensor[(2, 16), float32] */;
%76 = strided_slice(%75, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%77 = sigmoid(%76) /* ty=Tensor[(2, 4), float32] */;
%78 = strided_slice(%75, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%79 = sigmoid(%78) /* ty=Tensor[(2, 4), float32] */;
%80 = %state.17.1;
%81 = multiply(%79, %80) /* ty=Tensor[(2, 4), float32] */;
%82 = strided_slice(%75, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%83 = sigmoid(%82) /* ty=Tensor[(2, 4), float32] */;
%84 = strided_slice(%75, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%85 = tanh(%84) /* ty=Tensor[(2, 4), float32] */;
%86 = multiply(%83, %85) /* ty=Tensor[(2, 4), float32] */;
%87 = add(%81, %86) /* ty=Tensor[(2, 4), float32] */;
%88 = nn.layer_norm(%87, %v175, %v176) /* ty=Tensor[(2, 4), float32] */;
%89 = tanh(%88) /* ty=Tensor[(2, 4), float32] */;
%90 = multiply(%77, %89) /* ty=Tensor[(2, 4), float32] */;
%91 = (%90, %88);
%92 = (%90, %91);
%93 = %92.0;
%94 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%95 = Cons(%93, %94) /* ty=List[Tensor[(2, 4), float32]] */;
%96 = @concat(%95, %outputs.19) /* ty=List[Tensor[(2, 4), float32]] */;
%97 = %92.1;
%while_loop1(%62, %96, %97, %input.11) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */
} else {
(%i.11, %outputs.19, %state.17, %input.11)
}
};
%while_loop1
);
%99 = %98(0 /* ty=int32 */, %59, %60, %input) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */;
%100 = %99.1;
%101 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %100) /* ty=List[static_tensor_float32_2_4_t[]] */;
%102 = @tensor_array_stack_float32_2_4(%101) /* ty=static_tensor_float32_?_2_4_t[] */;
%103 = @tensor_get_data_float32_2_4(%102) /* ty=Tensor[(?, 2, 4), float32] */;
%104 = %99.2;
%105 = (%103, %104);
%106 = %105.0;
%107 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%108 = Cons(%106, %107) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%109 = @concat(%58, %108) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%110 = @map(tensor_constructor_float32_?_2_4(Tensor[(?, 2, 4), float32]), %109) /* ty=List[static_tensor_float32_?_2_4_t[]] */;
%111 = @tensor_array_concat_float32_?_2_4(%110) /* ty=static_tensor_float32_?_2_4_t[] */;
%112 = @tensor_get_data_float32_?_2_4(%111) /* ty=Tensor[(?, 2, 4), float32] */;
%113 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%114 = %54.1;
%115 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%116 = Cons(%114, %115) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%117 = @concat(%113, %116) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%118 = %105.1;
%119 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%120 = Cons(%118, %119) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%121 = @concat(%117, %120) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%122 = (%112, %121);
%123 = %122.0;
%161 = (
let %while_loop2: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) = fn (%i.14: int32, %outputs.25: List[Tensor[(2, 4), float32]], %state.21: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %output.2: Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) {
%124 = shape_of(%123, dtype="int32") /* ty=Tensor[(3), int32] */;
%125 = take(%124, 0 /* ty=int32 */, axis=0) /* ty=int32 */;
%126 = less(%i.14, %125) /* ty=bool */;
if (%126) {
%127 = add(%i.14, 1 /* ty=int32 */) /* ty=int32 */;
%128 = take(%output.2, %i.14, axis=0) /* ty=Tensor[(2, 4), float32] */;
%129 = transpose(%v222, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%130 = transpose(%129, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%131 = nn.dense(%128, %130, units=None) /* ty=Tensor[(2, 16), float32] */;
%132 = nn.layer_norm(%131, %v228, %v229) /* ty=Tensor[(2, 16), float32] */;
%133 = %state.21.0;
%134 = transpose(%v234, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%135 = transpose(%134, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%136 = nn.dense(%133, %135, units=None) /* ty=Tensor[(2, 16), float32] */;
%137 = nn.layer_norm(%136, %v240, %v241) /* ty=Tensor[(2, 16), float32] */;
%138 = add(%132, %137) /* ty=Tensor[(2, 16), float32] */;
%139 = strided_slice(%138, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%140 = sigmoid(%139) /* ty=Tensor[(2, 4), float32] */;
%141 = strided_slice(%138, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%142 = sigmoid(%141) /* ty=Tensor[(2, 4), float32] */;
%143 = %state.21.1;
%144 = multiply(%142, %143) /* ty=Tensor[(2, 4), float32] */;
%145 = strided_slice(%138, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%146 = sigmoid(%145) /* ty=Tensor[(2, 4), float32] */;
%147 = strided_slice(%138, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%148 = tanh(%147) /* ty=Tensor[(2, 4), float32] */;
%149 = multiply(%146, %148) /* ty=Tensor[(2, 4), float32] */;
%150 = add(%144, %149) /* ty=Tensor[(2, 4), float32] */;
%151 = nn.layer_norm(%150, %v262, %v263) /* ty=Tensor[(2, 4), float32] */;
%152 = tanh(%151) /* ty=Tensor[(2, 4), float32] */;
%153 = multiply(%140, %152) /* ty=Tensor[(2, 4), float32] */;
%154 = (%153, %151);
%155 = (%153, %154);
%156 = %155.0;
%157 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%158 = Cons(%156, %157) /* ty=List[Tensor[(2, 4), float32]] */;
%159 = @concat(%outputs.25, %158) /* ty=List[Tensor[(2, 4), float32]] */;
%160 = %155.1;
%while_loop2(%127, %159, %160, %output.2) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) */
} else {
(%i.14, %outputs.25, %state.21, %output.2)
}
};
%while_loop2
);
%162 = %161(0 /* ty=int32 */, %5, %7, %123) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) */;
%163 = %162.1;
%164 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %163) /* ty=List[static_tensor_float32_2_4_t[]] */;
%165 = @tensor_array_stack_float32_2_4(%164) /* ty=static_tensor_float32_?_2_4_t[] */;
%166 = @tensor_get_data_float32_2_4(%165) /* ty=Tensor[(?, 2, 4), float32] */;
%167 = %162.2;
%168 = (%166, %167);
%169 = %168.0;
%170 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%171 = Cons(%169, %170) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%172 = @concat(%4, %171) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%173 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%174 = @nth(%6, 1 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%175 = shape_of(%123, dtype="int32") /* ty=Tensor[(3), int32] */;
%176 = take(%175, 0 /* ty=int32 */, axis=0) /* ty=int32 */;
%214 = (
let %while_loop3: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) = fn (%i.15: int32, %outputs.30: List[Tensor[(2, 4), float32]], %state.25: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %seq_len.3: int32, %output.21: Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) {
%177 = less(%i.15, %176) /* ty=bool */;
if (%177) {
%178 = add(%i.15, 1 /* ty=int32 */) /* ty=int32 */;
%179 = subtract(%seq_len.3, %i.15) /* ty=int32 */;
%180 = subtract(%179, 1 /* ty=int32 */) /* ty=int32 */;
%181 = take(%output.21, %180, axis=0) /* ty=Tensor[(2, 4), float32] */;
%182 = transpose(%v303, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%183 = transpose(%182, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%184 = nn.dense(%181, %183, units=None) /* ty=Tensor[(2, 16), float32] */;
%185 = nn.layer_norm(%184, %v309, %v310) /* ty=Tensor[(2, 16), float32] */;
%186 = %state.25.0;
%187 = transpose(%v315, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%188 = transpose(%187, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%189 = nn.dense(%186, %188, units=None) /* ty=Tensor[(2, 16), float32] */;
%190 = nn.layer_norm(%189, %v321, %v322) /* ty=Tensor[(2, 16), float32] */;
%191 = add(%185, %190) /* ty=Tensor[(2, 16), float32] */;
%192 = strided_slice(%191, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%193 = sigmoid(%192) /* ty=Tensor[(2, 4), float32] */;
%194 = strided_slice(%191, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%195 = sigmoid(%194) /* ty=Tensor[(2, 4), float32] */;
%196 = %state.25.1;
%197 = multiply(%195, %196) /* ty=Tensor[(2, 4), float32] */;
%198 = strided_slice(%191, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%199 = sigmoid(%198) /* ty=Tensor[(2, 4), float32] */;
%200 = strided_slice(%191, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%201 = tanh(%200) /* ty=Tensor[(2, 4), float32] */;
%202 = multiply(%199, %201) /* ty=Tensor[(2, 4), float32] */;
%203 = add(%197, %202) /* ty=Tensor[(2, 4), float32] */;
%204 = nn.layer_norm(%203, %v343, %v344) /* ty=Tensor[(2, 4), float32] */;
%205 = tanh(%204) /* ty=Tensor[(2, 4), float32] */;
%206 = multiply(%193, %205) /* ty=Tensor[(2, 4), float32] */;
%207 = (%206, %204);
%208 = (%206, %207);
%209 = %208.0;
%210 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%211 = Cons(%209, %210) /* ty=List[Tensor[(2, 4), float32]] */;
%212 = @concat(%211, %outputs.30) /* ty=List[Tensor[(2, 4), float32]] */;
%213 = %208.1;
%while_loop3(%178, %212, %213, %seq_len.3, %output.21) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) */
} else {
(%i.15, %outputs.30, %state.25, %seq_len.3, %output.21)
}
};
%while_loop3
);
%215 = %214(0 /* ty=int32 */, %173, %174, %176, %123) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) */;
%216 = %215.1;
%217 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %216) /* ty=List[static_tensor_float32_2_4_t[]] */;
%218 = @tensor_array_stack_float32_2_4(%217) /* ty=static_tensor_float32_?_2_4_t[] */;
%219 = @tensor_get_data_float32_2_4(%218) /* ty=Tensor[(?, 2, 4), float32] */;
%220 = %215.2;
%221 = (%219, %220);
%222 = %221.0;
%223 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%224 = Cons(%222, %223) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%225 = @concat(%172, %224) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%226 = @map(tensor_constructor_float32_?_2_4(Tensor[(?, 2, 4), float32]), %225) /* ty=List[static_tensor_float32_?_2_4_t[]] */;
%227 = @tensor_array_concat_float32_?_2_4(%226) /* ty=static_tensor_float32_?_2_4_t[] */;
%228 = @tensor_get_data_float32_?_2_4(%227) /* ty=Tensor[(?, 2, 4), float32] */;
%229 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%230 = %168.1;
%231 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%232 = Cons(%230, %231) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%233 = @concat(%229, %232) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%234 = %221.1;
%235 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%236 = Cons(%234, %235) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%237 = @concat(%233, %236) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%238 = (%228, %237);
%239 = %238.0;
%277 = (
let %while_loop4: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) = fn (%i.4: int32, %outputs.9: List[Tensor[(2, 4), float32]], %state.7: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %output.4: Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) {
%240 = shape_of(%239, dtype="int32") /* ty=Tensor[(3), int32] */;
%241 = take(%240, 0 /* ty=int32 */, axis=0) /* ty=int32 */;
%242 = less(%i.4, %241) /* ty=bool */;
if (%242) {
%243 = add(%i.4, 1 /* ty=int32 */) /* ty=int32 */;
%244 = take(%output.4, %i.4, axis=0) /* ty=Tensor[(2, 4), float32] */;
%245 = transpose(%v390, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%246 = transpose(%245, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%247 = nn.dense(%244, %246, units=None) /* ty=Tensor[(2, 16), float32] */;
%248 = nn.layer_norm(%247, %v396, %v397) /* ty=Tensor[(2, 16), float32] */;
%249 = %state.7.0;
%250 = transpose(%v402, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%251 = transpose(%250, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%252 = nn.dense(%249, %251, units=None) /* ty=Tensor[(2, 16), float32] */;
%253 = nn.layer_norm(%252, %v408, %v409) /* ty=Tensor[(2, 16), float32] */;
%254 = add(%248, %253) /* ty=Tensor[(2, 16), float32] */;
%255 = strided_slice(%254, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%256 = sigmoid(%255) /* ty=Tensor[(2, 4), float32] */;
%257 = strided_slice(%254, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%258 = sigmoid(%257) /* ty=Tensor[(2, 4), float32] */;
%259 = %state.7.1;
%260 = multiply(%258, %259) /* ty=Tensor[(2, 4), float32] */;
%261 = strided_slice(%254, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%262 = sigmoid(%261) /* ty=Tensor[(2, 4), float32] */;
%263 = strided_slice(%254, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%264 = tanh(%263) /* ty=Tensor[(2, 4), float32] */;
%265 = multiply(%262, %264) /* ty=Tensor[(2, 4), float32] */;
%266 = add(%260, %265) /* ty=Tensor[(2, 4), float32] */;
%267 = nn.layer_norm(%266, %v430, %v431) /* ty=Tensor[(2, 4), float32] */;
%268 = tanh(%267) /* ty=Tensor[(2, 4), float32] */;
%269 = multiply(%256, %268) /* ty=Tensor[(2, 4), float32] */;
%270 = (%269, %267);
%271 = (%269, %270);
%272 = %271.0;
%273 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%274 = Cons(%272, %273) /* ty=List[Tensor[(2, 4), float32]] */;
%275 = @concat(%outputs.9, %274) /* ty=List[Tensor[(2, 4), float32]] */;
%276 = %271.1;
%while_loop4(%243, %275, %276, %output.4) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) */
} else {
(%i.4, %outputs.9, %state.7, %output.4)
}
};
%while_loop4
);
%278 = %277(0 /* ty=int32 */, %1, %3, %239) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(?, 2, 4), float32]) */;
%279 = %278.1;
%280 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %279) /* ty=List[static_tensor_float32_2_4_t[]] */;
%281 = @tensor_array_stack_float32_2_4(%280) /* ty=static_tensor_float32_?_2_4_t[] */;
%282 = @tensor_get_data_float32_2_4(%281) /* ty=Tensor[(?, 2, 4), float32] */;
%283 = %278.2;
%284 = (%282, %283);
%285 = %284.0;
%286 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%287 = Cons(%285, %286) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%288 = @concat(%0, %287) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%289 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%290 = @nth(%2, 1 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
%291 = shape_of(%239, dtype="int32") /* ty=Tensor[(3), int32] */;
%292 = take(%291, 0 /* ty=int32 */, axis=0) /* ty=int32 */;
%330 = (
let %while_loop5: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) = fn (%i.1: int32, %outputs.6: List[Tensor[(2, 4), float32]], %state.6: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %seq_len.1: int32, %output.41: Tensor[(?, 2, 4), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) {
%293 = less(%i.1, %292) /* ty=bool */;
if (%293) {
%294 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
%295 = subtract(%seq_len.1, %i.1) /* ty=int32 */;
%296 = subtract(%295, 1 /* ty=int32 */) /* ty=int32 */;
%297 = take(%output.41, %296, axis=0) /* ty=Tensor[(2, 4), float32] */;
%298 = transpose(%v471, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%299 = transpose(%298, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%300 = nn.dense(%297, %299, units=None) /* ty=Tensor[(2, 16), float32] */;
%301 = nn.layer_norm(%300, %v477, %v478) /* ty=Tensor[(2, 16), float32] */;
%302 = %state.6.0;
%303 = transpose(%v483, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%304 = transpose(%303, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%305 = nn.dense(%302, %304, units=None) /* ty=Tensor[(2, 16), float32] */;
%306 = nn.layer_norm(%305, %v489, %v490) /* ty=Tensor[(2, 16), float32] */;
%307 = add(%301, %306) /* ty=Tensor[(2, 16), float32] */;
%308 = strided_slice(%307, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%309 = sigmoid(%308) /* ty=Tensor[(2, 4), float32] */;
%310 = strided_slice(%307, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%311 = sigmoid(%310) /* ty=Tensor[(2, 4), float32] */;
%312 = %state.6.1;
%313 = multiply(%311, %312) /* ty=Tensor[(2, 4), float32] */;
%314 = strided_slice(%307, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%315 = sigmoid(%314) /* ty=Tensor[(2, 4), float32] */;
%316 = strided_slice(%307, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%317 = tanh(%316) /* ty=Tensor[(2, 4), float32] */;
%318 = multiply(%315, %317) /* ty=Tensor[(2, 4), float32] */;
%319 = add(%313, %318) /* ty=Tensor[(2, 4), float32] */;
%320 = nn.layer_norm(%319, %v511, %v512) /* ty=Tensor[(2, 4), float32] */;
%321 = tanh(%320) /* ty=Tensor[(2, 4), float32] */;
%322 = multiply(%309, %321) /* ty=Tensor[(2, 4), float32] */;
%323 = (%322, %320);
%324 = (%322, %323);
%325 = %324.0;
%326 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%327 = Cons(%325, %326) /* ty=List[Tensor[(2, 4), float32]] */;
%328 = @concat(%327, %outputs.6) /* ty=List[Tensor[(2, 4), float32]] */;
%329 = %324.1;
%while_loop5(%294, %328, %329, %seq_len.1, %output.41) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) */
} else {
(%i.1, %outputs.6, %state.6, %seq_len.1, %output.41)
}
};
%while_loop5
);
%331 = %330(0 /* ty=int32 */, %289, %290, %292, %239) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), int32, Tensor[(?, 2, 4), float32]) */;
%332 = %331.1;
%333 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %332) /* ty=List[static_tensor_float32_2_4_t[]] */;
%334 = @tensor_array_stack_float32_2_4(%333) /* ty=static_tensor_float32_?_2_4_t[] */;
%335 = @tensor_get_data_float32_2_4(%334) /* ty=Tensor[(?, 2, 4), float32] */;
%336 = %331.2;
%337 = (%335, %336);
%338 = %337.0;
%339 = Nil /* ty=List[Tensor[(?, 2, 4), float32]] */;
%340 = Cons(%338, %339) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%341 = @concat(%288, %340) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%342 = @map(tensor_constructor_float32_?_2_4(Tensor[(?, 2, 4), float32]), %341) /* ty=List[static_tensor_float32_?_2_4_t[]] */;
%343 = @tensor_array_concat_float32_?_2_4(%342) /* ty=static_tensor_float32_?_2_4_t[] */;
%344 = @tensor_get_data_float32_?_2_4(%343) /* ty=Tensor[(?, 2, 4), float32] */;
%345 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%346 = %284.1;
%347 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%348 = Cons(%346, %347) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%349 = @concat(%345, %348) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%350 = %337.1;
%351 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%352 = Cons(%350, %351) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%353 = @concat(%349, %352) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%354 = (%344, %353);
%355 = %354.0;
%356 = Nil /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%357 = %122.1;
%358 = Nil /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%359 = Cons(%357, %358) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%360 = @concat(%356, %359) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%361 = %238.1;
%362 = Nil /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%363 = Cons(%361, %362) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%364 = @concat(%360, %363) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%365 = %354.1;
%366 = Nil /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%367 = Cons(%365, %366) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
%368 = @concat(%364, %367) /* ty=List[List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]] */;
(%355, %368)
}
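Despite its length, the dump is six copies of the same loop body (forward and reverse direction for each of the three layers), differing only in which weight and normalization parameters they read. Each iteration, e.g. %24 through %39 in the first loop, is one layer-norm LSTM cell step: the two nn.dense + nn.layer_norm results are summed into a (2, 16) gate tensor, the four strided_slice calls split it into input, forget, candidate, and output gates in that order, the new cell state is layer-normalized, and the hidden state is o * tanh(LN(c)). The back-to-back transpose pairs (e.g. %15/%16) cancel out. A minimal NumPy sketch of that cell step, with hypothetical parameter names, reads as follows:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalization over the last axis, as in nn.layer_norm.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ln_lstm_cell(x, h, c, w_ih, w_hh, ln_ih, ln_hh, ln_c, hidden=4):
    """One step of the layer-norm LSTM cell that the Relay loops repeat.

    x: (batch, in_features), h/c: (batch, hidden),
    w_ih: (4*hidden, in_features), w_hh: (4*hidden, hidden),
    ln_*: (gamma, beta) pairs for the three layer_norm calls.
    """
    # nn.dense(x, W) computes x @ W.T; the transpose pairs in the dump cancel.
    gates = layer_norm(x @ w_ih.T, *ln_ih) + layer_norm(h @ w_hh.T, *ln_hh)
    i = sigmoid(gates[:, 0 * hidden:1 * hidden])   # strided_slice [0, 4)
    f = sigmoid(gates[:, 1 * hidden:2 * hidden])   # strided_slice [4, 8)
    g = np.tanh(gates[:, 2 * hidden:3 * hidden])   # strided_slice [8, 12)
    o = sigmoid(gates[:, 3 * hidden:4 * hidden])   # strided_slice [12, 16)
    c_new = layer_norm(f * c + i * g, *ln_c)       # the normalized cell state is carried forward
    h_new = o * np.tanh(c_new)
    return h_new, (h_new, c_new)
```

The forward loops append each hidden state to a Relay list with Cons/@concat while the reverse loops prepend it (compare %45 with %96); @tensor_array_stack then turns each list back into a (?, 2, 4) tensor, which is why the stacked output keeps a dynamic sequence dimension.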