Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 740db46

Browse files
xaedesggerganov
authored andcommittedAug 29, 2023
train : mem usage and other improvements (ggml-org#2439)
* fix track_max_mem in forward_batch_wo_cache_flash_attn_train * remove unnecessary Adam(W) optimizer tensors. reduces optimizer memory overhead from 7*modelsize to 2*modelsize. additionally allows to optimize models with more than 2^31 parameters by replacing int with int64_t. bumps training checkpoint file version, but old checkpoints can still be read. new version with less tensors is saved. * add gradient clipping to AdamW * Fix reset of unused g->nodes and g->grads to NULL * implement gradient checkpointing for training reduces memory overhead from O(n_layer) to O(sqrt(n_layer)) as explained in readme of https://github.com/cybertronai/gradient-checkpointing * remove unused compute buffer 3 * add and use function ggml_build_backward_expand to avoid stack overflows with large maximum number of nodes GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); * change AdamW decay parameter to work like the torch AdamW decay parameter It is now relative to Adam learning rate `alpha*sched`. Before that it was relative to `sched` only. `alpha` being the maximum learning rate and `sched` being a scaling parameter in [0..1] * change default AdamW weight decay parameter used in training to 0.1 as used in nanoGPT * change default AdamW weight decay parameter defined in ggml to 0.0, making Adam default instead of AdamW btw: the default weight decay parameter for torch.optim.AdamW is 0.01 * bug fixes for cross entropy loss ggml_cross_entropy_loss: sums where not correctly added in workload of each thread ggml_cross_entropy_loss_back: simplify backward process, reducing numerical issues guard usage of exp f16 lookup in cross entropy by #define GGML_CROSS_ENTROPY_EXP_FP16 cross entropy loss is only used once during training, but it is quite sensitive to numerical errors introduced by exp-f16-lookup. so exp-f16-lookup for cross entropy loss is disabled by default, trading better gradients for very slightly worse runtime performance. * fix test-grad0 for cross_entropy_loss the second argument to cross_entropy_loss must sum up to 1 for each row * fix test-grad0 for soft_max dont use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0) * improve finite differences of test-grad0 by using double instead of float * change cross_entropy_loss to output average over all rows this helps keeping the loss and gradients in a sane range * improve gradient checkpointing sqrt(n_layers) is only the best checkpoint step when mem size of checkpoints and mem size of layers are equal. since layers require more memory than the single-tensor-checkpoint we use, the optimal values are compute different: ``` given: n, u, v objective: minimize(a*u+b*v) where a*b=n, a>0, b>0 b=n/a minimize(a*u+v*n/a) diff(a*u+v*n/a, a) = u - (v*n/a)/a diff(a*u+v*n/a, a) == 0 u - (v*n/a)/a == 0 u == v*n/(a*a) u*a*a = v*n a*a = v*n/u a = sqrt(n*v/u) ``` this change results in more checkpoints, requiring less layers to store between checkpoints, overall improving memory usage. * disable gradient checkpointing debug output * llama : fix rope usage in train-text-from-scratch after ChatGLM change * add more training parameters: --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. --adam-min-alpha N Adam minimum learning rate alpha, usually 0.1 * alpha * replace memcpy with reshape operation so that the graph is not cut at the input this makes it possible to store other values into the input tensor and then simply recompute the graph without rebuilding it * remove unused function argument from get_example_targets_batch * measure and print total training time * add optimization callback to ggml_opt_resume_g this callback is called before each iteration with custom data and pointer to learning schedule parameter (only used in Adam(W)). can be used for dynamic learning schedule and setting input data for batches before each iteration * use optimization callback in training allows dynamic learning schedule and different batch data for each iteration without relying on low n_iter and high n_examples parameters reduces runtime by avoiding restart of optimization function and improves training convergence by providing a different batch for each iteration * add minimum number of tensor dimensions to apply weight decay (default 2) this allows to not apply weight decay to bias parameters * rename training parameter cos-decay-alpha to cos-decay-min and clarify that adam-min-alpha also applies to warmup * fix increase of model.train_samples and model.train_tokens now that each optimizer iteration gets its own batch we need to multiply by number of opt iterations * change sampling parameters for prediction after training to defaults of common.h and clarify what is context for prediction and what are generated tokens * tighten abs error bounds for cross_entropy_loss in test-grad0 * add conditional compilation of using F16 exp in flash attention uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` to enable usage of f16 exp in flash attention * tighten abs error bounds for flash_attn in test-grad0 * tighten abs error bounds for sqrt in test-grad0 * remove out-commented vectorized code of opt_adam the vectorized code might be bit faster for low number of parameters, but it had a big memory usage overhead * ggml : update ggml_rms_norm_back with configurable eps * llama training : fix ggml_rms_norm_back calls to pass configurable eps * remove trailing whitespace * add train function using automatic gradient checkpointing backward pass and allocator * in train function replace add_inplace by regular add because using add_inplace seems to result in different gradients * don't use allocate hash_map on context because the context has no_alloc=True when using memory allocator resulting in NULL data pointers * correctly clone reshape and permute operations by also cloning tensor->nb values * fix variable name and add missing type cast * terminate recursive tensor cloning when reaching tensor without src tensors * correctly clone view tensors by setting data pointers without this the checkpointing would only work when being used together with memory allocator * fix variable names * swap arguments to commutative ops to be the same as in `forward_batch_wo_cache_flash_attn` * add input tensors as checkpoints so that recursive tensor cloning of gradient checkpointing terminates on input tensors * fix variable name and add missing boolean negation * make sure some tensors are not reallocated by inserting new temporary nodes depending on them: output and parameter gradient tensors need to be available at the end of the graph execution parameter gradient tensors also need to be available before the graph execution because they are set to zero before each optimizer iteration checkpoint tensors are allocated all together to reduce memory allocator fragmentation afterwards, in addition to the temporary nodes, we also need to reset the temporary leafs * fix ASSERT to work with zero layers * add training options whether to use allocator and/or unified training function * integrate unified training function which may use memory allocator the unified training function also supports arguments whether to use flash attention and/or gradient checkpointing * format name of cloned tensors with " (clone)" suffix * set names for tensors in unified train function for easier debugging * allocate graph on context using ggml_new_graph * remove handwritten training functions * remove unused training parameters "use_scratch" and "use_unified" * remove trailing whitespace * remove unused train params: mem_compute1_gb & mem_compute2_gb mem_compute_gb is used for compute when automatic memory allocator is not enabled, otherwise it can be very small to only hold the tensor definitions mem_compute0_gb is used for automatic memory allocator (as long as measurement of max required size is not implemented) * remove unused forward_batch function * add debug asserts in ggml_allocr_alloc to some common pitfalls when using this function directly * only use ggml_allocr_alloc when tensor has NULL data and is no view * fix test when to create temporary backward graph temporary backward graph is only necessary when using checkpointing * fix memory "leak" in optimizers each iteration a new cplan with new memory for work data was allocated. now cplan creation only happens at the start of optimization, with each iteration reusing the cplan and its work data. * reverse order of for loop in ggml_build_backward_expand to save memory when using gradient checkpointing and allocator with this loop order gradient checkpointing with allocator on 16 layer model saves 13% memory; 2 layer memory it saves 2% memory. the computation results are the same * add missing lctx argument to get_example_targets_batch * implement llama model file saving using gguf checkpoint loading and saving disabled, to be replaced by loading and saving via gguf * implement loading/saving of checkpointing files using GGUF * bug fixes * add checkpoint file version for future compatibility * update readme with gguf filenames * save & load opt->just_initialized value * add first draft for checkpoint conversion script * add gguf arch and ftype * save opt parameter counter as uint64 * add gguf key and tensor names for optimizer and training * add layer_norm_rms_eps to checkpoint convert script * use same GGUF_GET_KEY macro as in llama.cpp * use norm_rms_eps, and rope parameters and command line options to set them * fix memory corruption bug in gguf ctx->kv and ctx->infos was reallocated using not-aligned realloc, but freed with aligned free. to fix this a GGML_ALIGNED_REALLOC was added, but there is no posix_memalign_realloc function. so on non-windows and non-mingw32 platforms we fall back to aligned malloc, followed by copying and freeing the old data. * add gguf example cmake file * bug fixes in tokenize_file * bug fixes in load_llama_model_gguf * bug fix: init model when no checkpoint was loaded * bug fix in read_tensor_by_name * bug fix in load_opt_context_gguf * avoid printing lots of spaced on the unusual case that loss gets nan * set name of tensors with empty name from what was read from gguf * remove trailing whitespace * print data checksums before saving and after loading to verify correctness * bug fixes for convert-train-checkpoint-to-gguf * temporarily add code to write old checkpoint files used to verify that old checkpoint files are correctly converted to gguf * bug fixes for convert-train-checkpoint-to-gguf.py loading checkpoints with opt_version=0 * remove code used to verify correctness of checkpoint file conversion * remove trailing whitespace * remove prediction related code use main for prediction, it is better optimized * update train-text-from-scratch README.md * fix non-windows GGML_ALIGNED_REALLOC * add missing blank line at end of file * remove GGML_ALIGNED_REALLOC and use normal malloc/realloc/free for gguf ctx->kv & ctx->infos * train : fix compile warnings --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent fa3edd6 commit 740db46

File tree

11 files changed

+1890
-2458
lines changed

11 files changed

+1890
-2458
lines changed
 

‎common/common.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <string>
1616
#include <unordered_set>
1717
#include <vector>
18+
#include <cinttypes>
1819

1920
#if defined(__APPLE__) && defined(__MACH__)
2021
#include <sys/types.h>
@@ -938,8 +939,8 @@ std::string get_sortable_timestamp() {
938939

939940
const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
940941
current_time.time_since_epoch() % 1000000000).count();
941-
char timestamp_ns[10];
942-
snprintf(timestamp_ns, 11, "%09ld", ns);
942+
char timestamp_ns[11];
943+
snprintf(timestamp_ns, 11, "%09" PRId64, ns);
943944

944945
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
945946
}

‎examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,6 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
681681

682682
// for rms-att-weight
683683
int row_length = model->hparams.n_embd;
684-
const auto & hparams = model->hparams;
685684
int n_ff = model->hparams.n_ff;
686685

687686
for (uint32_t i = 0; i < model->hparams.n_layer; ++i){

‎examples/gguf/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET gguf)
2+
add_executable(${TARGET} gguf.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_11)

‎examples/train-text-from-scratch/README.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/s
88

99
# train
1010
./bin/train-text-from-scratch \
11-
--vocab-model ../models/ggml-vocab.bin \
11+
--vocab-model ../models/ggml-vocab-llama.gguf \
1212
--ctx 64 --embd 256 --head 8 --layer 16 \
13-
--checkpoint-in chk-shakespeare-256x16.bin \
14-
--checkpoint-out chk-shakespeare-256x16.bin \
15-
--model-out ggml-shakespeare-256x16-f32.bin \
13+
--checkpoint-in chk-shakespeare-256x16.gguf \
14+
--checkpoint-out chk-shakespeare-256x16.gguf \
15+
--model-out ggml-shakespeare-256x16-f32.gguf \
1616
--train-data "shakespeare.txt" \
17-
-t 6 -b 16 -n 32 --seed 1 --adam-iter 16 \
18-
--print-details-interval 0 --predict 16 --use-flash
17+
-t 6 -b 16 --seed 1 --adam-iter 256 \
18+
--no-checkpointing
1919

2020
# predict
21-
./bin/main -m ggml-shakespeare-256x16-f32.bin
21+
./bin/main -m ggml-shakespeare-256x16-f32.gguf
2222
```

‎examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py

+492
Large diffs are not rendered by default.

‎examples/train-text-from-scratch/train-text-from-scratch.cpp

+1,151-2,249
Large diffs are not rendered by default.

‎ggml-alloc.c

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct g
107107
}
108108

109109
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
110+
#ifdef GGML_ALLOCATOR_DEBUG
111+
GGML_ASSERT(ggml_is_view(tensor) == false); // views generally get data pointer from one of their sources
112+
GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
113+
#endif
110114
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
111115
size = aligned_offset(NULL, size, alloc->alignment);
112116

‎ggml.c

+170-165
Large diffs are not rendered by default.

‎ggml.h

+17-12
Original file line numberDiff line numberDiff line change
@@ -952,11 +952,11 @@ extern "C" {
952952

953953
// a - x
954954
// b - dy
955-
// TODO: update with configurable eps
956955
GGML_API struct ggml_tensor * ggml_rms_norm_back(
957956
struct ggml_context * ctx,
958957
struct ggml_tensor * a,
959-
struct ggml_tensor * b);
958+
struct ggml_tensor * b,
959+
float eps);
960960

961961
// A: n columns, m rows
962962
// B: n columns, p rows (i.e. we transpose it internally)
@@ -1612,7 +1612,8 @@ extern "C" {
16121612
struct ggml_tensor * tensor);
16131613

16141614

1615-
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1615+
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1616+
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
16161617

16171618
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
16181619
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
@@ -1677,6 +1678,8 @@ extern "C" {
16771678
GGML_LINESEARCH_INVALID_PARAMETERS,
16781679
};
16791680

1681+
typedef void (*ggml_opt_callback)(void * data, float * sched);
1682+
16801683
// optimization parameters
16811684
//
16821685
// see ggml.c (ggml_opt_default_params) for default values
@@ -1712,12 +1715,14 @@ extern "C" {
17121715

17131716
float sched; // schedule multiplier (fixed, decay or warmup)
17141717
float decay; // weight decay for AdamW, use 0.0f to disable
1718+
int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
17151719
float alpha; // learning rate
17161720
float beta1;
17171721
float beta2;
17181722
float eps; // epsilon for numerical stability
17191723
float eps_f; // epsilon for convergence test
17201724
float eps_g; // epsilon for convergence test
1725+
float gclip; // gradient clipping
17211726
} adam;
17221727

17231728
// LBFGS parameters
@@ -1745,14 +1750,12 @@ extern "C" {
17451750

17461751
bool just_initialized;
17471752

1753+
float loss_before;
1754+
float loss_after;
1755+
17481756
struct {
1749-
struct ggml_tensor * x; // view of the parameters
1750-
struct ggml_tensor * g1; // gradient
1751-
struct ggml_tensor * g2; // gradient squared
17521757
struct ggml_tensor * m; // first moment
17531758
struct ggml_tensor * v; // second moment
1754-
struct ggml_tensor * mh; // first moment hat
1755-
struct ggml_tensor * vh; // second moment hat
17561759
struct ggml_tensor * pf; // past function values
17571760
float fx_best;
17581761
float fx_prev;
@@ -1789,10 +1792,10 @@ extern "C" {
17891792

17901793
// initialize optimizer context
17911794
GGML_API void ggml_opt_init(
1792-
struct ggml_context * ctx,
1795+
struct ggml_context * ctx,
17931796
struct ggml_opt_context * opt,
1794-
struct ggml_opt_params params,
1795-
int64_t nx);
1797+
struct ggml_opt_params params,
1798+
int64_t nx);
17961799

17971800
// continue optimizing the function defined by the tensor f
17981801
GGML_API enum ggml_opt_result ggml_opt_resume(
@@ -1806,7 +1809,9 @@ extern "C" {
18061809
struct ggml_opt_context * opt,
18071810
struct ggml_tensor * f,
18081811
struct ggml_cgraph * gf,
1809-
struct ggml_cgraph * gb);
1812+
struct ggml_cgraph * gb,
1813+
ggml_opt_callback callback,
1814+
void * callback_data);
18101815

18111816
//
18121817
// quantization

‎llama.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -6248,7 +6248,6 @@ const char * llama_print_system_info(void) {
62486248
}
62496249

62506250
void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
6251-
62526251
fprintf(stream, "\n");
62536252
fprintf(stream, "###########\n");
62546253
fprintf(stream, "# Timings #\n");
@@ -6264,10 +6263,10 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
62646263
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
62656264
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
62666265
fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample);
6267-
fprintf(stream, "t_eval_us: %ld # total microseconds spent generating tokens\n", ctx->t_eval_us);
6268-
fprintf(stream, "t_load_us: %ld # total microseconds spent loading the model\n", ctx->t_load_us);
6269-
fprintf(stream, "t_p_eval_us: %ld # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
6270-
fprintf(stream, "t_sample_us: %ld # total microseconds spent sampling\n", ctx->t_sample_us);
6266+
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
6267+
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
6268+
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
6269+
fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us);
62716270
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
62726271
1.0e6 * ctx->n_eval / ctx->t_eval_us);
62736272
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",

‎tests/test-grad0.cpp

+37-17
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,14 @@ static bool check_gradient(
275275

276276
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
277277

278-
const float f0 = ggml_get_f32_1d(f, 0);
278+
const double f0 = ggml_get_f32_1d(f, 0);
279279

280280
ggml_set_f32_1d(x[i], k, xm);
281281

282282
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
283283

284-
const float f1 = ggml_get_f32_1d(f, 0);
285-
const float g0 = (f0 - f1)/(2.0f*eps);
284+
const double f1 = ggml_get_f32_1d(f, 0);
285+
const double g0 = (f0 - f1)/(2.0*(double) eps);
286286

287287
ggml_set_f32_1d(x[i], k, x0);
288288

@@ -292,10 +292,10 @@ static bool check_gradient(
292292

293293
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
294294

295-
const float g1 = ggml_get_f32_1d(x[i]->grad, k);
295+
const double g1 = ggml_get_f32_1d(x[i]->grad, k);
296296

297-
const float error_abs = fabsf(g0 - g1);
298-
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
297+
const double error_abs = fabs(g0 - g1);
298+
const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0;
299299

300300
if (error_abs > max_error_abs || error_rel > max_error_rel) {
301301
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
@@ -531,7 +531,7 @@ int main(int argc, const char ** argv) {
531531

532532
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
533533

534-
check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
534+
check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f);
535535
}
536536
}
537537

@@ -1345,9 +1345,18 @@ int main(int argc, const char ** argv) {
13451345
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
13461346
ggml_set_param(ctx0, x[0]);
13471347

1348-
struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
1349-
1350-
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
1348+
float eps = 1e-6f;
1349+
// dont use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work
1350+
// instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)
1351+
struct ggml_tensor * f = ggml_sum(ctx0,
1352+
ggml_log(ctx0,
1353+
ggml_add1(ctx0,
1354+
ggml_scale(ctx0,
1355+
ggml_soft_max(ctx0, x[0]),
1356+
ggml_new_f32(ctx0, 1.0f - eps)),
1357+
ggml_new_f32(ctx0, eps))));
1358+
1359+
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
13511360
}
13521361
}
13531362

@@ -1358,15 +1367,26 @@ int main(int argc, const char ** argv) {
13581367
int64_t ne2[4];
13591368
get_random_dims(ne2, 4);
13601369

1361-
for (int ndims = 1; ndims <= 3; ++ndims) {
1362-
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
1370+
for (int ndims = 1; ndims <= 4; ++ndims) {
1371+
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
13631372
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
1373+
// the second argument to cross_entropy_loss must sum up to 1 for each row
1374+
int nr = ggml_nrows(x[1]);
1375+
int nc = ggml_nelements(x[1]) / nr;
1376+
for (int ir = 0; ir < nr; ++ir) {
1377+
float sum = 0;
1378+
for (int ic = 0; ic < nc; ++ic) {
1379+
sum += ((float *) x[1]->data)[ic + ir*nc];
1380+
}
1381+
for (int ic = 0; ic < nc; ++ic) {
1382+
((float *) x[1]->data)[ic + ir*nc] /= sum;
1383+
}
1384+
}
13641385
ggml_set_param(ctx0, x[0]);
13651386

1366-
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
1387+
struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
13671388

1368-
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY);
1369-
// finite differences regularly fails!
1389+
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY);
13701390
}
13711391
}
13721392

@@ -1473,7 +1493,7 @@ int main(int argc, const char ** argv) {
14731493

14741494
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
14751495

1476-
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
1496+
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
14771497
}
14781498
}
14791499
}
@@ -1514,7 +1534,7 @@ int main(int argc, const char ** argv) {
15141534

15151535
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
15161536

1517-
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
1537+
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
15181538
}
15191539
}
15201540
}

0 commit comments

Comments
 (0)
Please sign in to comment.