lin
2025-08-14 dae8bad597b6607a449b32bf76c523423f7720ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
 
    http://www.apache.org/licenses/LICENSE-2.0
 
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
 
using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
using mkldnn::use_global_stats;
using mkldnn::use_scale_shift;
 
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
 
struct MklBatchNormFwdParams {
  memory::dims src_dims;
  int depth;
  float eps;
  bool training;
 
  MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
                        bool training)
      : src_dims(src_dims), depth(depth), eps(eps), training(training) {}
};
 
template <typename T>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
    if (context_.bn_fwd == nullptr) Setup(fwdParams);
  }
 
  ~MklFusedBatchNormFwdPrimitive() {}
 
  // BatchNormalization forward execute
  //   src_data:     input data buffer of src
  //   weights_data: input data buffer of weights
  //   dst_data:     output data buffer of dst
  //   mean_data:     output data buffer of means
  //   variance_data: output data buffer of variances
  void Execute(const T* src_data, const T* weights_data, T* dst_data,
               T* mean_data, T* variance_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
 
    if (context_.flags & use_scale_shift)
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(weights_data)));
 
    if ((context_.pkind == prop_kind::forward_training) ||
        (context_.flags & use_global_stats)) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
    }
 
    // execution
    context_.fwd_stream->submit(context_.fwd_primitives);
 
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
 
    if (context_.flags & use_scale_shift)
      context_.weights_mem->set_data_handle(DummyData);
 
    if ((context_.pkind == prop_kind::forward_training) ||
        (context_.flags & use_global_stats)) {
      context_.mean_mem->set_data_handle(DummyData);
      context_.variance_mem->set_data_handle(DummyData);
    }
  }
 
  memory::primitive_desc GetDstPd() const {
    return (*context_.dst_mem).get_primitive_desc();
  }
 
  mkldnn_memory_format_t GetSrcFmt() const {
    return (*context_.src_mem).get_primitive_desc().desc().data.format;
  }
 
  mkldnn_memory_format_t GetDstFmt() const {
    return (*context_.dst_mem).get_primitive_desc().desc().data.format;
  }
 
 private:
  // Primitive reuse context for BatchNorm fwd op
  struct BatchNormFwdContext {
    // flags indict if it is training or inference mode
    int64 flags;
 
    // algorithm
    mkldnn::prop_kind pkind;
 
    // Mkldnn Memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> weights_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<mkldnn::memory> mean_mem;
    std::shared_ptr<mkldnn::memory> variance_mem;
 
    // BatchNorm forward primitive
    std::shared_ptr<mkldnn::primitive> bn_fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;
 
    BatchNormFwdContext()
        : flags(0),
          pkind(mkldnn::forward_training),
          src_mem(nullptr),
          weights_mem(nullptr),
          dst_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          bn_fwd(nullptr),
          fwd_stream(nullptr) {}
  };
 
  void Setup(const MklBatchNormFwdParams& fwdParams) {
    context_.flags = fwdParams.training ? use_scale_shift
                                        : (use_scale_shift | use_global_stats);
    context_.pkind = fwdParams.training ? prop_kind::forward_training
                                        : prop_kind::forward_scoring;
 
    // memory desc
    auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
                               get_desired_format(fwdParams.src_dims[1]));
 
    // fwd desc & primitive desc
    auto fwd_desc = batch_normalization_forward::desc(
        context_.pkind, src_md, fwdParams.eps, context_.flags);
    auto fwd_pd =
        batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_);
 
    // memory primitive
    context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
    context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData));
 
    if (context_.flags & use_scale_shift) {
      auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<T>(),
                                       memory::format::nc);
      context_.weights_mem.reset(
          new memory({weights_desc, cpu_engine_}, DummyData));
    }
 
    if (fwdParams.training || (context_.flags & use_global_stats)) {
      auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(),
                                    memory::format::nc);
      context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
 
      auto variance_desc =
          memory::desc({1, fwdParams.depth}, MklDnnType<T>(), memory::nc);
      context_.variance_mem.reset(
          new memory({variance_desc, cpu_engine_}, DummyData));
    }
 
    // BatchNorm forward primitive
    if (!fwdParams.training && !(context_.flags & use_global_stats)) {
      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, *context_.weights_mem,
            *context_.dst_mem));
      } else {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, *context_.dst_mem));
      }
    } else if (context_.flags & use_global_stats) {
      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
            (const primitive::at)*context_.variance_mem, *context_.weights_mem,
            *context_.dst_mem));
      } else {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
            (const primitive::at)*context_.variance_mem, *context_.dst_mem));
      }
    } else {
      if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem,
            *context_.mean_mem, *context_.variance_mem));
      } else {
        context_.bn_fwd.reset(new batch_normalization_forward(
            fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem,
            *context_.variance_mem));
      }
    }
 
    context_.fwd_primitives.push_back(*context_.bn_fwd);
  }
 
  mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const {
    return m.get_primitive_desc().desc().data;
  }
 
  struct BatchNormFwdContext context_;
  engine cpu_engine_;
};
 
template <typename T>
class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormFwdPrimitive<T>* Get(
      const MklBatchNormFwdParams& fwdParams) {
    auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T>*>(
        MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().GetBatchNormFwd(
            fwdParams));
 
    if (bn_fwd == nullptr) {
      bn_fwd = new MklFusedBatchNormFwdPrimitive<T>(fwdParams);
      MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().SetBatchNormFwd(
          fwdParams, bn_fwd);
    }
    return bn_fwd;
  }
 
  static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormFwdPrimitiveFactory instance_;
    return instance_;
  }
 
 private:
  MklFusedBatchNormFwdPrimitiveFactory() {}
  ~MklFusedBatchNormFwdPrimitiveFactory() {}
 
  static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
    string prefix = "bn_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(fwdParams.depth);
    key_creator.AddAsKey<float>(fwdParams.eps);
    key_creator.AddAsKey<bool>(fwdParams.training);
    return key_creator.GetKey();
  }
 
  MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }
 
  void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};
 
struct MklBatchNormBwdParams {
  memory::dims src_dims;
  memory::dims diff_dst_dims;
  int depth;
  float eps;
  bool training;
 
  MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
                        int depth, float eps, bool training)
      : src_dims(src_dims),
        diff_dst_dims(diff_dst_dims),
        depth(depth),
        eps(eps),
        training(training) {}
};
 
template <typename T>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
    if (context_.bn_bwd == nullptr) Setup(bwdParams);
  }
 
  ~MklFusedBatchNormBwdPrimitive() {}
 
  // BatchNormalization backward execute
  //   src_data:       input data buffer of src
  //   mean_data:      input data buffer of mean
  //   variance_data:  input data buffer of variance
  //   diff_dst_data:  input data buffer of diff_dst
  //   weights_data:   input data buffer of weights
  //   diff_src_data:      output data buffer of diff_src
  //   diff_weights_data:  output data buffer of diff_weights
  void Execute(const T* src_data, const T* mean_data, const T* variance_data,
               const T* diff_dst_data, const T* weights_data, T* diff_src_data,
               T* diff_weights_data) {
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(mean_data)));
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(variance_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));
 
    if (context_.flags & use_scale_shift) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(weights_data)));
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data));
    }
 
    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
 
    // execution
    context_.bwd_stream->submit(context_.bwd_primitives);
 
    context_.src_mem->set_data_handle(DummyData);
    context_.mean_mem->set_data_handle(DummyData);
    context_.variance_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    if (context_.flags & use_scale_shift) {
      context_.weights_mem->set_data_handle(DummyData);
      context_.diff_weights_mem->set_data_handle(DummyData);
    }
    context_.diff_src_mem->set_data_handle(DummyData);
  }
 
  mkldnn_memory_format_t GetSrcFmt() {
    return (*context_.src_mem).get_primitive_desc().desc().data.format;
  }
 
  mkldnn_memory_format_t GetDiffDstFmt() {
    return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format;
  }
 
  memory::primitive_desc GetDiffSrcPd() {
    return (*context_.diff_src_mem).get_primitive_desc();
  }
 
 private:
  struct BatchNormBwdContext {
    // Flags to indicate whether it is training or inference
    int64 flags;
 
    // MKLDNN memory
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> mean_mem;
    std::shared_ptr<mkldnn::memory> variance_mem;
    std::shared_ptr<mkldnn::memory> diff_dst_mem;
    std::shared_ptr<mkldnn::memory> weights_mem;
    std::shared_ptr<mkldnn::memory> diff_weights_mem;
    std::shared_ptr<mkldnn::memory> diff_src_mem;
 
    // Batch Norm primitive
    std::shared_ptr<mkldnn::primitive> bn_bwd;
    std::vector<mkldnn::primitive> bwd_primitives;
    std::shared_ptr<mkldnn::stream> bwd_stream;
 
    BatchNormBwdContext()
        : src_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          diff_dst_mem(nullptr),
          weights_mem(nullptr),
          diff_weights_mem(nullptr),
          diff_src_mem(nullptr),
          bwd_stream(nullptr) {}
  };
 
  void Setup(const MklBatchNormBwdParams& bwdParams) {
    context_.flags = bwdParams.training ? use_scale_shift
                                        : (use_scale_shift | use_global_stats);
 
    // memory desc
    auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
                               get_desired_format(bwdParams.src_dims[1]));
    auto diff_dst_md =
        memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
                     get_desired_format(bwdParams.diff_dst_dims[1]));
    auto variance_desc =
        memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::nc);
    auto mean_desc =
        memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::format::nc);
    auto weights_desc =
        memory::desc({2, bwdParams.depth}, MklDnnType<T>(), memory::format::nc);
    auto diff_weights_desc = weights_desc;
 
    // fwd desc & primitive desc
    auto fwd_desc = batch_normalization_forward::desc(
        prop_kind::forward_training, src_md, bwdParams.eps,
        bwdParams.training ? use_scale_shift
                           : (use_scale_shift | use_global_stats));
    auto fwd_pd =
        batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_);
 
    // BatchNorm backward primtive
    //
    // For inference, specify use_global_stats
    //   1. on fwd propagation, use mean and variance provided as inputs.
    //   2. on bwd propagation, mean and variance are considered as constants.
    //      Thus, reduce the amount of MKL computation.
    auto bwd_desc = batch_normalization_backward::desc(
        prop_kind::backward, diff_dst_md, src_md, bwdParams.eps,
        bwdParams.training ? use_scale_shift
                           : (use_scale_shift | use_global_stats));
    auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
        bwd_desc, cpu_engine_, fwd_pd);
 
    // memory primitive
    context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
    context_.diff_dst_mem.reset(
        new memory({diff_dst_md, cpu_engine_}, DummyData));
    context_.variance_mem.reset(
        new memory({variance_desc, cpu_engine_}, DummyData));
    context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
    context_.weights_mem.reset(
        new memory({weights_desc, cpu_engine_}, DummyData));
    context_.diff_weights_mem.reset(
        new memory({diff_weights_desc, cpu_engine_}, DummyData));
    context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
 
    context_.bn_bwd.reset(new batch_normalization_backward(
        bn_bwd_pd, *context_.src_mem, *context_.mean_mem,
        *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem,
        *context_.diff_src_mem, *context_.diff_weights_mem));
    context_.bwd_primitives.push_back(*context_.bn_bwd);
  }
 
  struct BatchNormBwdContext context_;
  engine cpu_engine_;
};
 
template <typename T>
class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormBwdPrimitive<T>* Get(
      const MklBatchNormBwdParams& bwdParams) {
    auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T>*>(
        MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().GetBatchNormBwd(
            bwdParams));
    if (bn_bwd == nullptr) {
      bn_bwd = new MklFusedBatchNormBwdPrimitive<T>(bwdParams);
      MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().SetBatchNormBwd(
          bwdParams, bn_bwd);
    }
    return bn_bwd;
  }
 
  static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormBwdPrimitiveFactory instance_;
    return instance_;
  }
 
 private:
  MklFusedBatchNormBwdPrimitiveFactory() {}
  ~MklFusedBatchNormBwdPrimitiveFactory() {}
 
  static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
    string prefix = "bn_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.diff_dst_dims);
    key_creator.AddAsKey<int>(bwdParams.depth);
    key_creator.AddAsKey<float>(bwdParams.eps);
    key_creator.AddAsKey<bool>(bwdParams.training);
    return key_creator.GetKey();
  }
 
  MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }
 
  void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};
 
template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
 public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }
 
  void Compute(OpKernelContext* context) override {
    try {
      const size_t kSrcIndex = 0;       // index of src input tensor
      const size_t kScaleIndex = 1;     // index of scale tensor
      const size_t kShiftIndex = 2;     // index of shift tensor
      const size_t kMeanIndex = 3;      // index of est_mean tensor
      const size_t kVarianceIndex = 4;  // index of est_variance tensor
 
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
      const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);
 
      TensorShape tf_shape_src;
      MklDnnShape dnn_shape_src;
      GetMklShape(context, kSrcIndex, &dnn_shape_src);
 
      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }
      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(context, shift_tensor.dims() == 1,
                  errors::InvalidArgument("offset must be 1-dimensional",
                                          shift_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_mean_tensor.dims() == 1,
          errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                  est_mean_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_variance_tensor.dims() == 1,
          errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                  est_variance_tensor.shape().DebugString()));
 
      if (is_training_) {
        OP_REQUIRES(
            context, est_mean_tensor.dim_size(0) == 0,
            errors::InvalidArgument("estimated_mean must be empty for training",
                                    est_mean_tensor.shape().DebugString()));
        OP_REQUIRES(context, est_variance_tensor.dim_size(0) == 0,
                    errors::InvalidArgument(
                        "estimated_variance must be empty for training",
                        est_variance_tensor.shape().DebugString()));
      }
 
      // special case: input with 0 element and 0 batch size
      Tensor* dst_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &dst_tensor);
        return;
      }
 
      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);
 
      // Indices of output tensors
      const size_t kDstIndex = 0;
 
      // allocate 4 output TF tensors
      Tensor* batch_mean_tensor = nullptr;
      Tensor* batch_variance_tensor = nullptr;
      Tensor* saved_mean_tensor = nullptr;
      Tensor* saved_variance_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
                        &batch_variance_tensor, &saved_mean_tensor,
                        &saved_variance_tensor);
 
      if (is_training_)
        SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
      else
        SetMeanVariance(est_mean_tensor, est_variance_tensor);
 
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> weights(&cpu_engine);
 
      memory::format format_m;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          format_m = memory::format::nchw;
        } else {
          format_m = memory::format::nhwc;
        }
      } else {
        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
      }
 
      // set src primitive
      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
 
      auto src_md = dnn_shape_src.IsMklTensor()
                        ? dnn_shape_src.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<T>(), format_m);
 
      // MKL-DNN packs scale & shift as "weights":
      // <scale>...<scale><shift>...<shift>
      weights.AllocateBuffer(2 * depth_ * sizeof(T));
      T* weights_data = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
      const T* scale_tf = scale_tensor.flat<T>().data();
      const T* shift_tf = shift_tensor.flat<T>().data();
 
      std::memcpy(weights_data, scale_tf, depth_ * sizeof(T));
      std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T));
      char* saved_mean_data_tf =
          reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data());
      std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
                  depth_ * sizeof(T));
 
      char* saved_variance_data_tf =
          reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data());
      std::memcpy(saved_variance_data_tf,
                  reinterpret_cast<char*>(variance_values_),
                  depth_ * sizeof(T));
 
      // get batchnorm op from the pool
      MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_);
      MklFusedBatchNormFwdPrimitive<T>* bn_fwd =
          MklFusedBatchNormFwdPrimitiveFactory<T>::Get(fwdParams);
 
      // check if reorder is needed for src, weights, mean, variance
      const T* src_data = src_tensor.flat<T>().data();
      if (src_md.data.format != bn_fwd->GetSrcFmt()) {
        src.SetUsrMem(src_md, &src_tensor);
        auto src_target = memory::primitive_desc(
            {{src_dims},
             MklDnnType<T>(),
             static_cast<memory::format>(bn_fwd->GetSrcFmt())},
            cpu_engine);
        src.CheckReorderToOpMem(src_target);
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
      }
 
      // allocate output (dst) tensor; always set it as MKL-DNN layout
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      dnn_shape_dst.SetMklTensor(true);
      auto dst_pd = bn_fwd->GetDstPd();
      dnn_shape_dst.SetMklLayout(&dst_pd);
      dnn_shape_dst.SetElemType(MklDnnType<T>());
      auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
                                               : src_tensor.shape().dims();
      dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m);
      tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);
 
      T* weights_op_data = weights_data;
      T* mean_op_data = saved_mean_tensor->flat<T>().data();
      T* variance_op_data = saved_variance_tensor->flat<T>().data();
      T* dst_data = dst_tensor->flat<T>().data();
 
      // execution
      bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
                      variance_op_data);
 
      // copy batch_mean data
      T* batch_mean_data_tf = batch_mean_tensor->flat<T>().data();
      std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf),
                  reinterpret_cast<char*>(saved_mean_data_tf),
                  depth_ * sizeof(T));
      // TODO(yli135): OpMem is same as usr mem since
      // since its format is hard-coded as nc when primitive is created.
 
      // copy batch_variance data with Bessel's correction
      float adjust_factor = 1.0;
      if (is_training_) {
        size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
        size_t adjust_size = orig_size - 1;
        adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      }
 
      auto variance_data = reinterpret_cast<T*>(saved_variance_data_tf);
      auto batch_variance_data = batch_variance_tensor->flat<T>().data();
      if (is_training_) {
        for (int k = 0; k < depth_; k++) {
          batch_variance_data[k] = variance_data[k] * adjust_factor;
        }
      } else {
        std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T));
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
 
 private:
  // NOTE(review): epsilon_, tensor_format_ and is_training_ are presumably
  // initialised from op attributes in the constructor, which is outside this
  // part of the file -- confirm there.
  T epsilon_;
  // Data format of the TF input; used to locate the 'C' dimension and to
  // derive the MKL-DNN memory format.
  TensorFormat tensor_format_;
  // When true, Compute applies Bessel's correction to the batch variance;
  // otherwise the variance is copied through unchanged.
  bool is_training_;
  // Non-owning pointers into the mean/variance tensors handed to
  // SetMeanVariance().
  T* mean_values_;
  T* variance_values_;
  // Number of channels ('C' dimension); batch normalization is performed per
  // channel.
  size_t depth_;
  // CPU engine (index 0) used when building MKL-DNN memory descriptors.
  engine cpu_engine = engine(engine::cpu, 0);
 
  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }
 
  void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
    mean_values_ = reinterpret_cast<T*>(const_cast<T*>(mean.flat<T>().data()));
    variance_values_ =
        reinterpret_cast<T*>(const_cast<T*>(variance.flat<T>().data()));
  }
 
  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale, Tensor** dst_tensor) {
    CHECK_NOTNULL(dst_tensor);
 
    const size_t kDstIndex = 0;
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
                              dnn_shape_dst);
    CHECK_NOTNULL(*dst_tensor);
    memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
           (*dst_tensor)->tensor_data().size());
 
    Tensor* batch_mean_tensor = nullptr;
    Tensor* batch_variance_tensor = nullptr;
    Tensor* saved_mean_tensor = nullptr;
    Tensor* saved_variance_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
                      &batch_variance_tensor, &saved_mean_tensor,
                      &saved_variance_tensor);
  }
 
  void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
                         Tensor** batch_mean_tensor,
                         Tensor** batch_variance_tensor,
                         Tensor** saved_mean_tensor,
                         Tensor** saved_variance_tensor) {
    CHECK_NOTNULL(batch_mean_tensor);
    CHECK_NOTNULL(batch_variance_tensor);
    CHECK_NOTNULL(saved_mean_tensor);
    CHECK_NOTNULL(saved_variance_tensor);
 
    const size_t kBatchMeanIndex = 1;
    const size_t kBatchVarianceIndex = 2;
    const size_t kSavedMeanIndex = 3;
    const size_t kSavedVarianceIndex = 4;
 
    // allocate batch mean output tensor
    MklDnnShape mkl_shape_batch_mean;
    mkl_shape_batch_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
                              tf_shape_scale, mkl_shape_batch_mean);
    CHECK_NOTNULL(*batch_mean_tensor);
    // set NAN mean value in case of empty input tensor
    int num_elements = tf_shape_scale.num_elements();
    auto batch_mean_data = (*batch_mean_tensor)->flat<T>().data();
    std::fill_n(batch_mean_data, num_elements, NAN);
 
    // allocate batch variance output tensor
    MklDnnShape mkl_shape_batch_variance;
    mkl_shape_batch_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchVarianceIndex,
                              batch_variance_tensor, tf_shape_scale,
                              mkl_shape_batch_variance);
    CHECK_NOTNULL(*batch_variance_tensor);
    // set NAN variance value in case of empty input tensor
    auto batch_variance_data = (*batch_variance_tensor)->flat<T>().data();
    std::fill_n(batch_variance_data, num_elements, NAN);
 
    // Mean and variance (without Bessel's correction) saved for backward
    // computation to serve as pre-computed mean and variance.
    MklDnnShape mkl_shape_saved_mean;
    mkl_shape_saved_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
                              tf_shape_scale, mkl_shape_saved_mean);
    CHECK_NOTNULL(*saved_mean_tensor);
    // set NAN mean value in case of empty input tensor
    auto saved_mean_data = (*saved_mean_tensor)->flat<T>().data();
    std::fill_n(saved_mean_data, num_elements, NAN);
 
    MklDnnShape mkl_shape_saved_variance;
    mkl_shape_saved_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedVarianceIndex,
                              saved_variance_tensor, tf_shape_scale,
                              mkl_shape_saved_variance);
    CHECK_NOTNULL(*saved_variance_tensor);
    // set NAN variance value in case of empty input tensor
    auto saved_variance_data = (*saved_variance_tensor)->flat<T>().data();
    std::fill_n(saved_variance_data, num_elements, NAN);
  }
};
 
template <typename Device, typename T>
class MklFusedBatchNormGradOp : public OpKernel {
 public:
  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }
 
  void Compute(OpKernelContext* context) override {
    try {
      const size_t kDiffDstIndex = 0;   // index of diff_dst tensor
      const size_t kSrcIndex = 1;       // index of src input tensor
      const size_t kScaleIndex = 2;     // index of scale tensor
      const size_t kMeanIndex = 3;      // index of saved_mean tensor
      const size_t kVarianceIndex = 4;  // index of saved_variance tensor
 
      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& saved_variance_tensor =
          MklGetInput(context, kVarianceIndex);
 
      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, kSrcIndex, &dnn_shape_src);
      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
 
      TensorShape tf_shape_src, tf_shape_diff_dst;
      if (dnn_shape_diff_dst.IsMklTensor()) {
        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
        OP_REQUIRES(
            context, dnn_shape_diff_dst.GetDimension() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      } else {
        tf_shape_diff_dst = diff_dst_tensor.shape();
        OP_REQUIRES(
            context, diff_dst_tensor.dims() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      }
 
      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }
 
      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, saved_mean_tensor.dims() == 1,
          errors::InvalidArgument("saved mean must be 1-dimensional",
                                  saved_mean_tensor.shape().DebugString()));
 
      OP_REQUIRES(
          context, saved_variance_tensor.dims() == 1,
          errors::InvalidArgument("saved variance must be 1-dimensional",
                                  saved_variance_tensor.shape().DebugString()));
 
      Tensor* diff_src_tensor = nullptr;
      // special case: input with 0 element and 0 batch size
      if (tf_shape_src.num_elements() == 0 ||
          tf_shape_diff_dst.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &diff_src_tensor);
        return;
      }
 
      if (dnn_shape_src.IsMklTensor()) {
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      } else if (dnn_shape_diff_dst.IsMklTensor()) {
        depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
      } else {
        ExtractParams(context);
      }
 
      memory::format format_m;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat())
          format_m = memory::format::nchw;
        else
          format_m = memory::format::nhwc;
      } else {
        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
      }
 
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);
      MklDnnData<T> weights(&cpu_engine);
      MklDnnData<T> diff_weights(&cpu_engine);
 
      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
      memory::dims diff_dst_dims =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                          tensor_format_);
 
      // set src and diff_dst primitive descriptors
      memory::desc src_md =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<T>(), format_m);
      memory::desc diff_dst_md =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m);
 
      // weights -- MKL DNN packs scales/ shifts as weights in order
      // of scale, ..., scale, shift, ...., shift
      weights.AllocateBuffer(2 * depth_ * sizeof(T));
      T* weights_data_tf = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
      const T* scale_tf = scale_tensor.flat<T>().data();
      for (int k = 0; k < depth_; k++) {
        weights_data_tf[k] = scale_tf[k];
        weights_data_tf[k + depth_] = 0;
      }
 
      diff_weights.AllocateBuffer(2 * depth_ * sizeof(T));
 
      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
                                      is_training_);
      MklFusedBatchNormBwdPrimitive<T>* bn_bwd =
          MklFusedBatchNormBwdPrimitiveFactory<T>::Get(bwdParams);
 
      // check if src/diff_dst need to be reordered
      const T* src_data = src_tensor.flat<T>().data();
      if (src_md.data.format != bn_bwd->GetSrcFmt()) {
        src.SetUsrMem(src_md, &src_tensor);
        auto src_target = memory::primitive_desc(
            {{src_dims},
             MklDnnType<T>(),
             static_cast<memory::format>(bn_bwd->GetSrcFmt())},
            cpu_engine);
        src.CheckReorderToOpMem(src_target);
        src_data = const_cast<T*>(
            reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
      }
 
      const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
      if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) {
        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
        auto diff_dst_target = memory::primitive_desc(
            {{diff_dst_dims},
             MklDnnType<T>(),
             static_cast<memory::format>(bn_bwd->GetDiffDstFmt())},
            cpu_engine);
        diff_dst.CheckReorderToOpMem(diff_dst_target);
        diff_dst_data = const_cast<T*>(
            reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
      }
 
      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;  // index of diff_src tensor
 
      // allocate output tensor: diff_src, always set as MKL-DNN layout
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      dnn_shape_diff_src.SetMklTensor(true);
      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
      dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m);
      dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src);
 
      T* mean_data =
          static_cast<T*>(const_cast<T*>(saved_mean_tensor.flat<T>().data()));
      T* variance_data = static_cast<T*>(
          const_cast<T*>(saved_variance_tensor.flat<T>().data()));
      T* weights_data = weights_data_tf;
      T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
      T* diff_weights_data = static_cast<T*>(diff_weights.GetAllocatedBuffer());
      // Execute
      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                      weights_data, diff_src_data, diff_weights_data);
 
      // allocate output TF tensors: diff_scale and diff_shift
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);
 
      // copy data: diff_scale and diff_shift
      auto diff_scale_data = diff_scale_tensor->flat<T>().data();
      auto diff_shift_data = diff_shift_tensor->flat<T>().data();
      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
                  reinterpret_cast<char*>(diff_weights_data),
                  depth_ * sizeof(T));
      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
                  reinterpret_cast<char*>(diff_weights_data + depth_),
                  depth_ * sizeof(T));
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
 
 private:
  T epsilon_;
  TensorFormat tensor_format_;
  int depth_;  // batch normalization is done for per channel.
  bool is_training_;
  engine cpu_engine = engine(engine::cpu, 0);
 
  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }
 
  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;
 
    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src);
    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
    std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 0);
 
    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }
 
  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    CHECK_NOTNULL(diff_scale_tensor);
    CHECK_NOTNULL(diff_shift_tensor);
 
    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;
 
    // separate out scale and shift grad and copy to individual tensors
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale);
    CHECK_NOTNULL(*diff_scale_tensor);
    auto diff_scale_data = (*diff_scale_tensor)->flat<T>().data();
    std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
                0);
 
    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift);
    CHECK_NOTNULL(*diff_shift_tensor);
    auto diff_shift_data = (*diff_shift_tensor)->flat<T>().data();
    std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
                0);
 
    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p);
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p);
  }
 
  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};
 
// Registers MklFusedBatchNormOp<CPUDevice, T> as the CPU kernel for the
// "_MklFusedBatchNorm" op, constrained on type attribute "T" and labelled
// with mkl_op_registry::kMklOpLabel. The macro is expanded for float only
// (via TF_CALL_float) and undefined right after, which lets the same macro
// name be redefined for the next registration.
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm")                \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
 
// Registers MklFusedBatchNormGradOp<CPUDevice, T> as the CPU kernel for the
// "_MklFusedBatchNormGrad" op, constrained on type attribute "T" and labelled
// with mkl_op_registry::kMklOpLabel. The macro is expanded for float only
// (via TF_CALL_float) and undefined right after use.
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad")            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow
 
#endif  // INTEL_MKL