-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathextensions.cpp
More file actions
502 lines (448 loc) · 21.7 KB
/
extensions.cpp
File metadata and controls
502 lines (448 loc) · 21.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
// Copyright (c) 2023, DeepLink.
#include <cstdint>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include "torch/library.h"
#include <ATen/core/ATen_fwd.h>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBody.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <c10/util/OptionalArrayRef.h>
#include <torch/csrc/utils/pybind.h> // IWYU pragma: keep
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_ext.h>
#include <csrc_dipu/diopirt/diopirt_impl.h>
#include <csrc_dipu/runtime/core/DIPUGeneratorImpl.h>
#include "diopi_helper.h"
#include "pybind_type_cast.h"
namespace dipu::dipu_ext {
// Python-facing AdamW step (exposed as `adamw` by the pybind module).
// All arguments are forwarded to diopiAdamW. Note that the DIOPI kernel
// expects `grad` immediately after `param`, and it has no "maximize" option.
void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
              c10::optional<at::Tensor>& max_exp_avg_sq_opt, at::Tensor& grad,
              float lr, float beta1, float beta2, float epsilon,
              float weight_decay, int64_t step, bool amsgrad) {
  callDiopi(diopiAdamW,
            // tensors: parameter, its gradient, then the moment buffers
            param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq_opt,
            // scalar hyper-parameters
            lr, beta1, beta2, epsilon, weight_decay, step, amsgrad);
}
// RMS normalization forward, delegated to diopiRMSNorm.
// `output` and `inv_rms` are passed as the mutable destination tensors.
// `bias_opt` may be empty.
void extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
                const at::Tensor& weight,
                const c10::optional<at::Tensor>& bias_opt, double eps) {
  callDiopi(diopiRMSNorm,
            /*destinations=*/output, inv_rms,
            /*sources=*/input, normalized_shape, weight, bias_opt,
            eps);
}
// RMS normalization backward, delegated to diopiRMSNormBackward.
// grad_input, grad_weight and (when present) grad_bias_opt are the mutable
// destination tensors. The remaining arguments are the upstream gradient and
// the forward-pass inputs/state.
void extRmsNormBackward(at::Tensor& grad_input, at::Tensor& grad_weight,
                        c10::optional<at::Tensor>& grad_bias_opt,
                        const at::Tensor& grad_output, const at::Tensor& input,
                        const at::Tensor& weight,
                        const c10::optional<at::Tensor>& bias_opt,
                        const at::Tensor& inv_rms,
                        const OptionalIntArray& normalized_shape, double eps) {
  callDiopi(diopiRMSNormBackward,
            // gradients to be written
            grad_input, grad_weight, grad_bias_opt,
            // upstream gradient and forward-pass inputs
            grad_output, input, weight, bias_opt, inv_rms, normalized_shape,
            eps);
}
// Rotary position embedding via diopiRotaryEmbedding.
// `output` is the destination tensor. The `conj` and `interleaved` flags are
// forwarded to the kernel unchanged.
void extApplyRotary(at::Tensor& output, const at::Tensor& input,
                    const at::Tensor& cos, const at::Tensor& sin,
                    const bool conj, const bool interleaved) {
  callDiopi(diopiRotaryEmbedding,
            output,
            input, cos, sin,
            conj, interleaved);
}
// Multi-head attention forward via diopiMultiHeadAttention.
//
// Allocates every result buffer here and returns
// (out, softmax_lse, gen, debug_attn_mask):
//   - out:             empty_like(q);
//   - softmax_lse:     float32, [batch, heads, q_seq_len];
//   - gen:             a freshly created DIPU generator handed to the kernel;
//   - debug_attn_mask: bool, [batch, heads, q_seq_len, k_seq_len] when
//                      `return_debug_mask` is set, otherwise shape [0].
// The size indexing assumes q is laid out as [batch, seq, heads, ...] and k as
// [batch, seq, ...]. TODO: confirm this against the DIOPI spec.
auto extMultiHeadAttention(at::Tensor& q, at::Tensor& k, at::Tensor& v,
                           double dropout_p, bool is_causal,
                           bool return_debug_mask, double scale) {
  const auto q_shape = q.sizes();
  const auto batch_size = q_shape[0];
  const auto q_seq_len = q_shape[1];
  const auto head_num = q_shape[2];
  const auto k_seq_len = k.sizes()[1];

  auto attn_out = at::empty_like(q);

  const IntArray lse_shape{batch_size, head_num, q_seq_len};
  const auto lse_options = q.options().dtype(at::kFloat);
  auto lse = at::empty(lse_shape, lse_options);

  auto gen = createDIPUGenerator();

  // The debug mask is only materialised on request; otherwise an empty
  // placeholder is still passed so the kernel always gets a tensor.
  const auto mask_shape =
      return_debug_mask ? IntArray{batch_size, head_num, q_seq_len, k_seq_len}
                        : IntArray{0};
  const auto mask_options = q.options().dtype(at::kBool);
  auto mask = at::empty(mask_shape, mask_options);

  callDiopi(diopiMultiHeadAttention,
            q, k, v, dropout_p, is_causal, return_debug_mask, scale,
            attn_out, lse, gen, mask);

  return std::make_tuple(std::move(attn_out), std::move(lse), std::move(gen),
                         std::move(mask));
}
// Multi-head attention backward via diopiMultiHeadAttentionBackward.
//
// grad_q / grad_k / grad_v are the output gradients. When the caller supplies
// one, that tensor is passed to the kernel as the destination. Otherwise a
// buffer is allocated with empty_like(q / k / v).
// Returns (grad_q, grad_k, grad_v).
auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
                                   const at::Tensor& q, const at::Tensor& k,
                                   const at::Tensor& v, const at::Tensor& out,
                                   const at::Tensor& softmax_lse,
                                   double dropout_p, bool is_causal,
                                   at::Generator& gen, double scale,
                                   c10::optional<at::Tensor>& grad_q_opt,
                                   c10::optional<at::Tensor>& grad_k_opt,
                                   c10::optional<at::Tensor>& grad_v_opt) {
  // Reuse the caller's buffer when present, else allocate one shaped like `ref`.
  const auto buffer_or_new = [](c10::optional<at::Tensor>& maybe,
                                const at::Tensor& ref) {
    return maybe.has_value() ? maybe.value() : at::empty_like(ref);
  };
  auto dq = buffer_or_new(grad_q_opt, q);
  auto dk = buffer_or_new(grad_k_opt, k);
  auto dv = buffer_or_new(grad_v_opt, v);

  callDiopi(diopiMultiHeadAttentionBackward,
            grad_out, q, k, v, out, softmax_lse,
            dropout_p, is_causal, gen, scale,
            dq, dk, dv);

  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
// Variable-length multi-head attention forward via
// diopiMultiHeadAttentionVarLen.
//
// The head count is read from q.sizes()[1]. The batch size is taken as
// cum_seq_q.sizes()[0] - 1, which presumably means cum_seq_q holds cumulative
// offsets with a leading 0 (TODO: confirm).
// Returns (out, softmax_lse, gen, debug_attn_mask):
//   - softmax_lse is float32 [batch, heads, max_q];
//   - debug_attn_mask is bool [batch, heads, max_q, max_k] when
//     `return_debug_mask` is set, otherwise shape [0].
auto extMultiHeadAttentionVarLen(at::Tensor& q, at::Tensor& k, at::Tensor& v,
                                 const at::Tensor& cum_seq_q,
                                 const at::Tensor& cum_seq_k,
                                 std::int64_t max_q, std::int64_t max_k,
                                 double dropout_p, bool is_causal,
                                 bool return_debug_mask, double scale) {
  const auto head_num = q.sizes()[1];
  const auto batch_size = cum_seq_q.sizes()[0] - 1;

  auto attn_out = at::empty_like(q);

  const IntArray lse_shape{batch_size, head_num, max_q};
  const auto lse_options = q.options().dtype(at::kFloat);
  auto lse = at::empty(lse_shape, lse_options);

  auto gen = createDIPUGenerator();

  const auto mask_shape =
      return_debug_mask ? IntArray{batch_size, head_num, max_q, max_k}
                        : IntArray{0};
  const auto mask_options = q.options().dtype(at::kBool);
  auto mask = at::empty(mask_shape, mask_options);

  callDiopi(diopiMultiHeadAttentionVarLen,
            q, k, v, cum_seq_q, cum_seq_k, max_q, max_k,
            dropout_p, is_causal, return_debug_mask, scale,
            attn_out, lse, gen, mask);

  return std::make_tuple(std::move(attn_out), std::move(lse), std::move(gen),
                         std::move(mask));
}
// Variable-length multi-head attention backward via
// diopiMultiHeadAttentionVarLenBackward.
//
// grad_q / grad_k / grad_v are the output gradients. A caller-supplied tensor
// is passed to the kernel as the destination; a missing one is allocated with
// empty_like(q / k / v).
// Returns (grad_q, grad_k, grad_v).
auto extMultiHeadAttentionVarLenBackward(
    const at::Tensor& grad_out, const at::Tensor& q, const at::Tensor& k,
    const at::Tensor& v, const at::Tensor& out, const at::Tensor& softmax_lse,
    const at::Tensor& cum_seq_q, const at::Tensor& cum_seq_k,
    std::int64_t max_q, std::int64_t max_k, double dropout_p, bool is_causal,
    at::Generator& gen, double scale, c10::optional<at::Tensor>& grad_q_opt,
    c10::optional<at::Tensor>& grad_k_opt,
    c10::optional<at::Tensor>& grad_v_opt) {
  // Reuse the caller's buffer when present, else allocate one shaped like `ref`.
  const auto buffer_or_new = [](c10::optional<at::Tensor>& maybe,
                                const at::Tensor& ref) {
    return maybe.has_value() ? maybe.value() : at::empty_like(ref);
  };
  auto dq = buffer_or_new(grad_q_opt, q);
  auto dk = buffer_or_new(grad_k_opt, k);
  auto dv = buffer_or_new(grad_v_opt, v);

  callDiopi(diopiMultiHeadAttentionVarLenBackward,
            grad_out, q, k, v, out, softmax_lse,
            cum_seq_q, cum_seq_k, max_q, max_k,
            dropout_p, is_causal, gen, scale,
            dq, dk, dv);

  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
// FlashAttention forward via diopiFlashAttention.
//
// `out` is the caller-provided destination. The kernel reports its auxiliary
// tensors through the diopiTensorHandle_t out-parameters. `context` is held
// until the return statement has copied those tensors out; judging by the
// name of callDiopiKeepContext, it keeps the kernel-side allocations alive
// (confirm in diopi_helper).
//
// Returns (attention_mask, dropout_mask, softmax_max, softmax_sum,
// softmax_out). Any handle the kernel left null comes back as an undefined
// at::Tensor. The handle is never dereferenced, which would be UB.
auto extFlashAttention(at::Tensor& out, const at::Tensor& q,
                       const at::Tensor& k, const at::Tensor& v,
                       at::Generator& gen, double p_dropout,
                       double softmax_scale, bool is_causal, int64_t head_num,
                       const std::string& input_layout) {
  diopiTensorHandle_t attention_mask = nullptr;
  diopiTensorHandle_t dropout_mask = nullptr;
  diopiTensorHandle_t softmax_max = nullptr;
  diopiTensorHandle_t softmax_sum = nullptr;
  diopiTensorHandle_t softmax_out = nullptr;
  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttention, out, &attention_mask, &dropout_mask, &softmax_max,
      &softmax_sum, &softmax_out, gen, q, k, v, p_dropout, softmax_scale,
      is_causal, head_num, input_layout.c_str());

  // Null handle -> undefined tensor; otherwise share the kernel's result.
  // All handles go through the same check. Previously only the two masks
  // were null-checked.
  const auto to_tensor = [](diopiTensorHandle_t handle) -> at::Tensor {
    return handle ? *dipu::diopi_helper::fromDiopiTensorHandle(handle)
                  : at::Tensor();
  };
  return std::make_tuple(to_tensor(attention_mask), to_tensor(dropout_mask),
                         to_tensor(softmax_max), to_tensor(softmax_sum),
                         to_tensor(softmax_out));
}
// FlashAttention forward (V2 API) via diopiFlashAttentionV2.
//
// This differs from the V1 wrapper in that the caller supplies
// `attention_mask`, and there is no `is_causal` flag. `out` is the
// caller-provided destination. The kernel reports its auxiliary tensors
// through the handle out-parameters. `context` is held until they have been
// copied out; by its name it presumably keeps the kernel-side allocations
// alive (confirm in diopi_helper).
//
// Returns (dropout_mask, softmax_max, softmax_sum, softmax_out). A handle the
// kernel left null becomes an undefined at::Tensor instead of being
// dereferenced.
auto extFlashAttentionV2(at::Tensor& out, const at::Tensor& q,
                         const at::Tensor& k, const at::Tensor& v,
                         at::Generator& gen, const at::Tensor& attention_mask,
                         double p_dropout, double softmax_scale,
                         int64_t head_num, const std::string& input_layout) {
  diopiTensorHandle_t dropout_mask = nullptr;
  diopiTensorHandle_t softmax_max = nullptr;
  diopiTensorHandle_t softmax_sum = nullptr;
  diopiTensorHandle_t softmax_out = nullptr;
  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttentionV2, out, &dropout_mask, &softmax_max, &softmax_sum,
      &softmax_out, gen, q, k, v, attention_mask, p_dropout, softmax_scale,
      head_num, input_layout.c_str());

  // Null handle -> undefined tensor; otherwise share the kernel's result.
  const auto to_tensor = [](diopiTensorHandle_t handle) -> at::Tensor {
    return handle ? *dipu::diopi_helper::fromDiopiTensorHandle(handle)
                  : at::Tensor();
  };
  return std::make_tuple(to_tensor(dropout_mask), to_tensor(softmax_max),
                         to_tensor(softmax_sum), to_tensor(softmax_out));
}
// FlashAttention backward via diopiFlashAttentionBackward.
//
// grad_q / grad_k / grad_v are caller-provided destination tensors passed to
// the kernel. They are also returned as (grad_q, grad_k, grad_v) for
// convenience. attention_mask / dropout_mask are optional and forwarded as
// given.
auto extFlashAttentionBackward(at::Tensor& grad_q, at::Tensor& grad_k,
                               at::Tensor& grad_v, const at::Tensor& grad_out,
                               const at::Tensor& q, const at::Tensor& k,
                               const at::Tensor& v, const at::Tensor& out,
                               const c10::optional<at::Tensor>& attention_mask,
                               const c10::optional<at::Tensor>& dropout_mask,
                               const at::Tensor& softmax_max,
                               const at::Tensor& softmax_sum,
                               const at::Tensor& softmax_out, double p_dropout,
                               double softmax_scale, int64_t head_num,
                               const std::string& input_layout) {
  callDiopi(diopiFlashAttentionBackward, grad_q, grad_k, grad_v, grad_out, q, k,
            v, out, attention_mask, dropout_mask, softmax_max, softmax_sum,
            softmax_out, p_dropout, softmax_scale, head_num,
            input_layout.c_str());
  // Copy, don't move. These are references to the caller's tensors, and
  // moving from them would leave the caller's grad_q/grad_k/grad_v undefined.
  // An at::Tensor copy only bumps a refcount and shares the same storage.
  return std::make_tuple(grad_q, grad_k, grad_v);
}
// Scaled masked softmax forward via diopiScaledMaskedSoftmax.
// `out` is the destination. `scale` and `fixed_triu_mask` are forwarded as-is.
void extScaledMaskedSoftmax(at::Tensor& out, const at::Tensor& input,
                            const at::Tensor& mask, double scale,
                            bool fixed_triu_mask) {
  callDiopi(diopiScaledMaskedSoftmax,
            out,
            input, mask,
            scale, fixed_triu_mask);
}
// Scaled masked softmax backward via diopiScaledMaskedSoftmaxBackward.
// `grad_input` is the destination. `out` is the forward result, and
// `mask` / `scale` / `fixed_triu_mask` mirror the forward call.
void extScaledMaskedSoftmaxBackward(at::Tensor& grad_input,
                                    const at::Tensor& grad_output,
                                    const at::Tensor& out,
                                    const at::Tensor& mask, double scale,
                                    bool fixed_triu_mask) {
  callDiopi(diopiScaledMaskedSoftmaxBackward,
            grad_input,
            grad_output, out, mask,
            scale, fixed_triu_mask);
}
// Index-copies `k` into `out` at the positions given by `dest_loc`, via
// diopiDestIndexCopyKV. The Python-facing order is (k, dest_loc, out), while
// the DIOPI kernel takes the destination first.
void extDestIndexCopyKV(const at::Tensor& k, const at::Tensor& dest_loc,
                        at::Tensor& out) {
  callDiopi(diopiDestIndexCopyKV, /*out=*/out, /*k=*/k, /*dest_loc=*/dest_loc);
}
// Token-attention inference kernel (diopiTokenAttentionInference).
// `out` is the destination. Note that DIOPI takes it first, even though it
// sits third in this wrapper's signature.
void extTokenAttentionInference(const at::Tensor& q, const at::Tensor& k,
                                at::Tensor& out, const at::Tensor& b_loc,
                                const at::Tensor& b_start_loc,
                                const at::Tensor& b_seq_len,
                                int max_input_len) {
  callDiopi(diopiTokenAttentionInference,
            out,
            q, k,
            b_loc, b_start_loc, b_seq_len,
            max_input_len);
}
// Softmax + reduce-V inference kernel (diopiTokenSoftmaxReduceVInference).
// `out` is the destination and is passed to DIOPI first. The integer
// arguments are forwarded unchanged.
void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
                                     const at::Tensor& v, at::Tensor& out,
                                     const at::Tensor& b_loc,
                                     const at::Tensor& b_start_loc,
                                     const at::Tensor& b_seq_len,
                                     int max_input_len, int other_kv_index) {
  callDiopi(diopiTokenSoftmaxReduceVInference,
            out,
            logics, v,
            b_loc, b_start_loc, b_seq_len,
            max_input_len, other_kv_index);
}
// Context-attention inference kernel (diopiContextAttentionInference).
// `out` is the destination and is passed to DIOPI first.
void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
                                  const at::Tensor& v, at::Tensor& out,
                                  const at::Tensor& b_start_loc,
                                  const at::Tensor& b_seq_len,
                                  int max_input_len) {
  callDiopi(diopiContextAttentionInference,
            out,
            q, k, v,
            b_start_loc, b_seq_len,
            max_input_len);
}
// Applies presence/frequency penalties to `logits` via diopiApplyPenalty.
// `logits` is the mutable tensor. Everything else is a read-only input.
void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
                     const at::Tensor& frequency_penalty,
                     const at::Tensor& p_token_ids,
                     const at::Tensor& p_token_counts,
                     const at::Tensor& p_cumsum_seq_len,
                     int p_max_len_in_batch) {
  callDiopi(diopiApplyPenalty,
            logits,
            presence_penalty, frequency_penalty,
            p_token_ids, p_token_counts, p_cumsum_seq_len,
            p_max_len_in_batch);
}
// Python extension module exposing the ext* wrappers through pybind11.
// For each op, check whether a corresponding DIOPI implementation exists:
// if it does, bind the wrapper directly with pybind;
// otherwise do not register it, and leave the op to the Python layer.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Check if weak symbol defined: the diopi* kernels are weak symbols, so an
// unresolved one has a null address and its op is simply not exposed.
if (&diopiAdamW != nullptr) {
m.def("adamw", &extAdamW, "deeplink ext_adamw");
}
if (&diopiFlashAttention != nullptr) {
m.def("fa_fwd", &extFlashAttention, "deeplink ext_fa_fwd");
}
if (&diopiFlashAttentionV2 != nullptr) {
m.def("fa_fwd_v2", &extFlashAttentionV2, "deeplink ext_fa_fwd_v2");
}
if (&diopiFlashAttentionBackward != nullptr) {
m.def("fa_bwd", &extFlashAttentionBackward, "deeplink ext_fa_bwd");
}
if (&diopiRMSNorm != nullptr) {
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
}
if (&diopiRMSNormBackward != nullptr) {
m.def("rms_norm_backward", &extRmsNormBackward,
"deeplink ext_rms_norm_backward");
}
if (&diopiRotaryEmbedding != nullptr) {
m.def("apply_rotary", &extApplyRotary, "deeplink ext_apply_rotary");
}
if (&diopiMultiHeadAttention != nullptr) {
m.def("mha_fwd", &extMultiHeadAttention, "deeplink ext_mha_fwd");
}
if (&diopiMultiHeadAttentionBackward != nullptr) {
m.def("mha_bwd", &extMultiHeadAttentionBackward, "deeplink ext_mha_bwd");
}
if (&diopiMultiHeadAttentionVarLen != nullptr) {
m.def("mha_varlen_fwd", &extMultiHeadAttentionVarLen,
"deeplink ext_mha_varlen_fwd");
}
if (&diopiMultiHeadAttentionVarLenBackward != nullptr) {
m.def("mha_varlen_bwd", &extMultiHeadAttentionVarLenBackward,
"deeplink ext_mha_varlen_bwd");
}
if (&diopiDestIndexCopyKV != nullptr) {
m.def("dest_index_copy_kv", &extDestIndexCopyKV,
"deeplink ext_dest_index_copy_kv");
}
if (&diopiTokenAttentionInference != nullptr) {
m.def("token_attention_inference", &extTokenAttentionInference,
"deeplink ext_token_attention_inference");
}
if (&diopiTokenSoftmaxReduceVInference != nullptr) {
m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
"deeplink ext_token_softmax_reducev_inference");
}
if (&diopiContextAttentionInference != nullptr) {
m.def("context_attention_inference", &extContextAttentionInference,
"deeplink ext_context_attention_inference");
}
if (&diopiApplyPenalty != nullptr) {
m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
}
if (&diopiScaledMaskedSoftmax != nullptr) {
m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
"deeplink ext_scaled_masked_softmax_fwd");
}
if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
"deeplink ext_scaled_masked_softmax_bwd");
}
}
// Dispatcher kernel for deeplink_ext_::adamw: performs one AdamW step through
// diopiAdamW and returns references to the updated param / exp_avg /
// exp_avg_sq. diopiAdamW has no "maximize" parameter.
std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> adamw(
    at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
    const c10::optional<at::Tensor>& max_exp_avg_sq_opt, const at::Tensor& grad,
    double lr, double beta1, double beta2, double epsilon, double weight_decay,
    int64_t step, bool amsgrad) {
  // TODO: grad is logically read-only. The const is stripped presumably
  // because callDiopi expects a non-const tensor in this slot.
  auto& mutable_grad = const_cast<at::Tensor&>(grad);
  // An absent max_exp_avg_sq is forwarded as an undefined tensor.
  at::Tensor max_exp_avg_sq =
      max_exp_avg_sq_opt.has_value() ? *max_exp_avg_sq_opt : at::Tensor();

  callDiopi(diopiAdamW, param, mutable_grad, exp_avg, exp_avg_sq,
            max_exp_avg_sq, lr, beta1, beta2, epsilon, weight_decay, step,
            amsgrad);

  return std::forward_as_tuple(param, exp_avg, exp_avg_sq);
}
// Dispatcher kernel for deeplink_ext_::apply_penalty: applies the penalties
// to `logits` through diopiApplyPenalty and returns `logits` itself.
at::Tensor& apply_penalty(at::Tensor& logits,
                          const at::Tensor& presence_penalty,
                          const at::Tensor& frequency_penalty,
                          const at::Tensor& p_token_ids,
                          const at::Tensor& p_token_counts,
                          const at::Tensor& p_cumsum_seq_len,
                          int64_t p_max_len_in_batch) {
  callDiopi(diopiApplyPenalty,
            logits,
            presence_penalty, frequency_penalty,
            p_token_ids, p_token_counts, p_cumsum_seq_len,
            p_max_len_in_batch);
  return logits;
}
// Dispatcher kernel for deeplink_ext_::dest_index_copy_kv: copies `k` into
// `out` at `dest_loc` via diopiDestIndexCopyKV and returns `out`.
// The DIOPI kernel takes the destination as its first argument.
at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc,
                               at::Tensor& out) {
  callDiopi(diopiDestIndexCopyKV, /*out=*/out, /*k=*/k, /*dest_loc=*/dest_loc);
  return out;
}
// Dispatcher kernel for deeplink_ext_::rms_norm: runs diopiRMSNorm into the
// caller's `output` / `inv_rms` tensors and returns references to them.
std::tuple<at::Tensor&, at::Tensor&> rms_norm(
    at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input,
    const OptionalIntArray& normalized_shape, const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt, double eps) {
  callDiopi(diopiRMSNorm,
            output, inv_rms,
            input, normalized_shape, weight, bias_opt,
            eps);
  return std::forward_as_tuple(output, inv_rms);
}
// Dispatcher kernel for deeplink_ext_::rms_norm_backward: runs
// diopiRMSNormBackward into the caller's grad_input / grad_weight /
// grad_bias_opt tensors and returns references to them.
std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> rms_norm_backward(
    at::Tensor& grad_input, at::Tensor& grad_weight, at::Tensor& grad_bias_opt,
    const at::Tensor& grad_output, const at::Tensor& input,
    const at::Tensor& weight, const c10::optional<at::Tensor>& bias_opt,
    const at::Tensor& inv_rms, const OptionalIntArray& normalized_shape,
    double eps) {
  callDiopi(diopiRMSNormBackward,
            // destinations
            grad_input, grad_weight, grad_bias_opt,
            // upstream gradient and forward-pass inputs
            grad_output, input, weight, bias_opt, inv_rms, normalized_shape,
            eps);
  return std::forward_as_tuple(grad_input, grad_weight, grad_bias_opt);
}
// Dispatcher kernel for deeplink_ext_::apply_rotary: runs diopiRotaryEmbedding
// into `output` and returns `output`.
at::Tensor& apply_rotary(at::Tensor& output, const at::Tensor& input,
                         const at::Tensor& cos, const at::Tensor& sin,
                         const bool conj, const bool interleaved) {
  callDiopi(diopiRotaryEmbedding,
            output,
            input, cos, sin,
            conj, interleaved);
  return output;
}
// Demo kernel registered as deeplink_ext_::example for every backend.
// It prints "<function name>: <tensor options>" to stdout and returns
// `inout` untouched.
at::Tensor& example_for_all_backend(at::Tensor& inout) {
  std::cout << __FUNCTION__ << ": ";
  std::cout << inout.options() << "\n";
  return inout;
}
// XPU-only variant of the demo kernel. Its m.impl registration is currently
// commented out. It prints "<function name>: <tensor options>" and returns
// `inout` untouched.
at::Tensor& example_only_for_xpu(at::Tensor& inout) {
  std::cout << __FUNCTION__ << ": ";
  std::cout << inout.options() << "\n";
  return inout;
}
// Defines the deeplink_ext_ custom-op library.
//
// By default, the kernels given to m.def() here are registered for all
// backends (XPU, AutocastXPU, AutogradXPU, CUDA, PrivateUse1,
// AutogradPrivateUse1, etc.). If a backend needs its own kernel, register it
// separately with TORCH_LIBRARY_IMPL.
//
// Each op is only defined when the address of its diopi* kernel symbol is
// non-null, i.e. when the DIOPI implementation is actually available.
//
// Schema argument order must match the C++ kernel's parameter order. The
// dispatcher forwards arguments positionally, and only the types (not the
// names) are checked against the kernel signature.
TORCH_LIBRARY(deeplink_ext_, m) {
  if (&diopiAdamW != nullptr) {
    m.def(
        "adamw(Tensor(a!) param, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, "
        "Tensor? max_exp_avg_sq_opt, Tensor grad, float lr, float beta1, float "
        "beta2, float epsilon, float weight_decay, int step, bool "
        "amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!))",
        adamw);
  }
  if (&diopiApplyPenalty != nullptr) {
    m.def(
        "apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor "
        "frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor "
        "p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)",
        apply_penalty);
  }
  if (&diopiDestIndexCopyKV != nullptr) {
    // Matches dest_index_copy_kv(k, dest_loc, out). Listing `out` first
    // would feed it into the kernel's `k` parameter and put the (a!)
    // mutation annotation on the wrong tensor.
    m.def(
        "dest_index_copy_kv(Tensor k, Tensor dest_loc, Tensor(a!) out)"
        "->Tensor(a!)",
        dest_index_copy_kv);
  }
  if (&diopiRMSNorm != nullptr) {
    m.def(
        "rms_norm(Tensor(a!) output, Tensor(b!) inv_rms, Tensor input, int[]? "
        "normalized_shape, Tensor weight, Tensor? bias_opt, float eps) -> "
        "(Tensor(a!), Tensor(b!))",
        rms_norm);
  }
  if (&diopiRMSNormBackward != nullptr) {
    m.def(
        "rms_norm_backward(Tensor(a!) grad_input, Tensor(b!) grad_weight, "
        "Tensor(c!) grad_bias_opt, Tensor grad_output, Tensor input, Tensor "
        "weight, Tensor? bias_opt, Tensor inv_rms, int[]? normalized_shape, "
        "float eps) -> (Tensor(a!), Tensor(b!), Tensor(c!))",
        rms_norm_backward);
  }
  if (&diopiRotaryEmbedding != nullptr) {
    m.def(
        "apply_rotary(Tensor(a!) output, Tensor input, Tensor cos, Tensor sin, "
        "bool conj, bool interleaved) -> Tensor(a!)",
        apply_rotary);
  }
  // The demo op has no DIOPI dependency and is always defined.
  m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend);
}
// XPU-only kernel overrides ("only impl for dipu"): registrations placed here
// apply to the XPU dispatch key alone. The block is currently empty, so XPU
// uses the default registrations from TORCH_LIBRARY(deeplink_ext_, ...).
TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) {
// Uncomment to route "example" on XPU to example_only_for_xpu instead.
// m.impl("example", example_only_for_xpu);
}
// Prints a one-line notice when this extension is loaded. The lambda runs
// during dynamic initialization of namespace-scope variables, which for a
// shared library typically happens at load time. The resulting value is
// unused. A variable whose initializer has side effects must not be
// eliminated, so the message is still printed.
//
// The variable used to be a mutable, externally linked global named `n`,
// which risked symbol clashes. It is now a const with internal linkage and a
// descriptive name.
namespace {
[[maybe_unused]] const int kLoadNoticePrinted = []() {
  std::cout << "deeplink_ext_ loaded" << std::endl;
  return 0;
}();
}  // namespace
} // namespace dipu::dipu_ext