diff --git a/CMakePresets.json b/CMakePresets.json index 64b7fd58ad6..669decdd619 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -40,7 +40,17 @@ "name": "CUDA 13", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_FLAGS": "-t 4", + "OLLAMA_RUNNER_DIR": "cuda_v13" + } + }, + { + "name": "CUDA 13 Windows", + "inherits": [ "CUDA" ], + "description": "Reduced architecture set for Windows to avoid MSVC template compilation issues", + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;89-virtual;100-virtual;120-virtual", "CMAKE_CUDA_FLAGS": "-t 4", "OLLAMA_RUNNER_DIR": "cuda_v13" } @@ -138,6 +148,11 @@ "inherits": [ "CUDA" ], "configurePreset": "CUDA 13" }, + { + "name": "CUDA 13 Windows", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 13 Windows" + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], diff --git a/Dockerfile b/Dockerfile index 43f511465a5..b5d67eca87b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 -ARG VULKANVERSION=1.4.321.1 +ARG VULKANVERSION=1.4.341.1 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 RUN dnf install -y yum-utils ccache gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ gcc-toolset-11-binutils \ diff --git a/Makefile.sync b/Makefile.sync index c1c24f2f5dc..6f642d9995e 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=ec98e2002 +FETCH_HEAD=feefb928367a01c1912975f0b277a48a14bbcadf .PHONY: help help: diff --git a/app/ui/ui.go b/app/ui/ui.go index 
f720fe05aa2..7238c1316f5 100644 --- a/app/ui/ui.go +++ b/app/ui/ui.go @@ -1672,7 +1672,6 @@ func supportsBrowserTools(model string) bool { return strings.HasPrefix(strings.ToLower(model), "gpt-oss") } - // buildChatRequest converts store.Chat to api.ChatRequest func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) { var msgs []api.Message diff --git a/integration/embed_test.go b/integration/embed_test.go index e4506673940..57c8b9b972f 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -73,13 +73,18 @@ func manhattanDistance[V float32 | float64](v1, v2 []V) V { } func TestEmbedCosineDistanceCorrelation(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + started := time.Now() for _, model := range libraryEmbedModels { t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } testCases := []struct { a string b string @@ -489,14 +494,19 @@ func TestEmbedTruncation(t *testing.T) { // TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes. 
func TestEmbedLargeInput(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + started := time.Now() for _, model := range libraryEmbedModels { model := model t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute) defer mcancel() diff --git a/integration/tools_test.go b/integration/tools_test.go index 39b3e1a91eb..193706187cd 100644 --- a/integration/tools_test.go +++ b/integration/tools_test.go @@ -21,9 +21,10 @@ func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { } func TestAPIToolCalling(t *testing.T) { - initialTimeout := 60 * time.Second - streamTimeout := 60 * time.Second - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + initialTimeout := 90 * time.Second + streamTimeout := 90 * time.Second + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) @@ -47,8 +48,12 @@ func TestAPIToolCalling(t *testing.T) { "granite3.3": 7, } + started := time.Now() for _, model := range libraryToolsModels { t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } if v, ok := minVRAM[model]; ok { skipUnderMinVRAM(t, v) } diff --git a/llama/README.md b/llama/README.md index bfe66a8b43f..40298dc9f54 100644 --- a/llama/README.md +++ b/llama/README.md @@ -14,25 +14,28 @@ make -f Makefile.sync apply-patches ### Updating Base Commit -**Pin to new base commit** +To update to a new base commit: -To change the base commit, update `FETCH_HEAD` in Makefile.sync. +1. 
**Update FETCH_HEAD** in `Makefile.sync` to the new commit hash. -When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution. +2. **Check for upstreamed patches**: Before applying, review if any patches have been merged upstream. Remove those patches from `./patches/` to avoid conflicts. -Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure. +3. **Apply patches**: + ```shell + make -f Makefile.sync apply-patches + ``` -```shell -make -f Makefile.sync apply-patches -``` - -If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied. +4. **Resolve conflicts** (if any): When `git am` fails on a patch: + - Fix conflicts in `./vendor/` + - Stage the resolved files: `git -C llama/vendor add ` + - Continue: `git -C llama/vendor am --continue` + - Re-run: `make -f Makefile.sync apply-patches` + - Repeat until all patches are applied. -Once all patches are applied, commit the changes to the tracking repository. - -```shell -make -f Makefile.sync format-patches sync -``` +5. 
**Regenerate patches and sync**: + ```shell + make -f Makefile.sync format-patches sync + ``` ### Generating Patches diff --git a/llama/build-info.cpp b/llama/build-info.cpp index b37cd25efad..80e6bb6670a 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "ec98e2002"; +char const *LLAMA_COMMIT = "feefb928367a01c1912975f0b277a48a14bbcadf"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/.rsync-filter b/llama/llama.cpp/.rsync-filter index 7987be1c83d..e149fa9a7ad 100644 --- a/llama/llama.cpp/.rsync-filter +++ b/llama/llama.cpp/.rsync-filter @@ -7,6 +7,7 @@ include /common/json-schema-to-grammar.* include /common/json.* include /common/log.* include /common/sampling.* +include /common/unicode.* include /include/ include /include/llama.* include /include/llama-*.* diff --git a/llama/llama.cpp/LICENSE b/llama/llama.cpp/LICENSE index acb96ce78e0..e7dca554bcb 100644 --- a/llama/llama.cpp/LICENSE +++ b/llama/llama.cpp/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2023-2026 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 5a8cf524856..53bddc4ef2f 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -1,7 +1,3 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml.h" #include "gguf.h" @@ -9,12 +5,12 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "unicode.h" #include #include #include #include -#include #include #include #include @@ -251,7 +247,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { case GGML_SCHED_PRIO_REALTIME: p = -20; break; } - if 
(!setpriority(PRIO_PROCESS, 0, p)) { + if (setpriority(PRIO_PROCESS, 0, p) != 0) { LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); return false; } @@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } -bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { - return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; -} - -bool string_remove_suffix(std::string & str, const std::string_view & suffix) { - bool has_suffix = string_ends_with(str, suffix); - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - -size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { - if (!str.empty() && !stop.empty()) { - const char text_last_char = str.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const auto current_partial = stop.substr(0, char_index + 1); - if (string_ends_with(str, current_partial)) { - return str.size() - char_index - 1; - } - } - } - } - - return std::string::npos; -} - std::string regex_escape(const std::string & s) { static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); return std::regex_replace(s, special_chars, "\\$&"); @@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { return false; } - std::u32string filename_utf32; - try { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert, char32_t> converter; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif 
defined(__GNUC__) -# pragma GCC diagnostic pop -#endif + size_t offset = 0; + while (offset < filename.size()) { + utf8_parse_result result = parse_utf8_codepoint(filename, offset); - filename_utf32 = converter.from_bytes(filename); + if (result.status != utf8_parse_result::SUCCESS) { + return false; + } + uint32_t c = result.codepoint; - // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, - // or invalid encodings were encountered. Reject such attempts - std::string filename_reencoded = converter.to_bytes(filename_utf32); - if (filename_reencoded != filename) { + if ((result.bytes_consumed == 2 && c < 0x80) || + (result.bytes_consumed == 3 && c < 0x800) || + (result.bytes_consumed == 4 && c < 0x10000)) { return false; } - } catch (const std::exception &) { - return false; - } - // Check for forbidden codepoints: - // - Control characters - // - Unicode equivalents of illegal characters - // - UTF-16 surrogate pairs - // - UTF-8 replacement character - // - Byte order mark (BOM) - // - Illegal characters: / \ : * ? " < > | - for (char32_t c : filename_utf32) { + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? 
" < > | if (c <= 0x1F // Control characters (C0) || c == 0x7F // Control characters (DEL) || (c >= 0x80 && c <= 0x9F) // Control characters (C1) @@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { || c == 0x2215 // Division Slash (forward slash equivalent) || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c > 0x10FFFF // Max Unicode limit || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) || c == ':' || c == '*' // Illegal characters @@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { // Subdirectories not allowed, reject path separators return false; } + offset += result.bytes_consumed; } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -898,7 +851,8 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else if (std::getenv("HOME")) { @@ -1078,12 +1032,15 @@ struct common_init_result::impl { impl() = default; ~impl() = default; + // note: the order in which model, context, etc. 
are declared matters because their destructors will be called bottom-to-top + llama_model_ptr model; llama_context_ptr context; std::vector lora; std::vector samplers; + std::vector samplers_seq_config; }; common_init_result::common_init_result(common_params & params) : @@ -1092,9 +1049,12 @@ common_init_result::common_init_result(common_params & params) : auto cparams = common_context_params_to_llama(params); if (params.fit_params) { - LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__); + LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, + params.tensor_buft_overrides.data(), + params.fit_params_target.data(), + params.fit_params_min_ctx, params.verbosity >= 4 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } @@ -1107,6 +1067,25 @@ common_init_result::common_init_result(common_params & params) : const llama_vocab * vocab = llama_model_get_vocab(model); + // load and optionally apply lora adapters (must be loaded before context creation) + for (auto & la : params.lora_adapters) { + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str()); + pimpl->model.reset(model); + return; + } + + char buf[1024]; + la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; + pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + } + // updates params.sampling // TODO: fix naming common_init_sampler_from_model(model, params.sampling); @@ -1141,10 +1120,18 @@ common_init_result::common_init_result(common_params & params) : // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); //} + // init the backend samplers as part of the context creation pimpl->samplers.resize(cparams.n_seq_max); + pimpl->samplers_seq_config.resize(cparams.n_seq_max); for (int i = 0; i < (int) cparams.n_seq_max; ++i) { pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; + } + + if (params.sampling.backend_sampling) { + cparams.samplers = pimpl->samplers_seq_config.data(); + cparams.n_samplers = pimpl->samplers_seq_config.size(); } llama_context * lctx = llama_init_from_model(model, cparams); @@ -1168,12 +1155,14 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) { return pimpl->samplers[seq_id].get(); } -std::vector & common_init_result::lora() { - return pimpl->lora; +void 
common_init_result::reset_samplers() { + for (int i = 0; i < (int) pimpl->samplers.size(); ++i) { + llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get())); + } } -void common_init_result::free_context() { - pimpl->context.reset(); +std::vector & common_init_result::lora() { + return pimpl->lora; } common_init_result_ptr common_init_from_params(common_params & params) { @@ -1207,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { return res; } - int err = llama_apply_adapter_cvec( + int err = llama_set_adapter_cvec( lctx, cvec.data.data(), cvec.data.size(), @@ -1243,24 +1232,6 @@ common_init_result_ptr common_init_from_params(common_params & params) { } } - // load and optionally apply lora adapters - for (auto & la : params.lora_adapters) { - llama_adapter_lora_ptr lora; - lora.reset(llama_adapter_lora_init(model, la.path.c_str())); - if (lora == nullptr) { - LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - return res; - } - - char buf[1024]; - la.ptr = lora.get(); - llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); - la.task_name = buf; - llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); - la.prompt_prefix = buf; - res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters - } - if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } @@ -1301,6 +1272,9 @@ common_init_result_ptr common_init_from_params(common_params & params) { llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); + + // reset samplers to reset RNG state after warmup to the seeded state + res->reset_samplers(); } return res; @@ -1324,12 +1298,15 @@ std::string get_model_endpoint() { } void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { - llama_clear_adapter_lora(ctx); - for (auto & la : lora) { - if (la.scale != 0.0f) { - 
llama_set_adapter_lora(ctx, la.ptr, la.scale); - } + std::vector loras; + std::vector scales; + + for (auto & la: lora) { + loras.push_back(la.ptr); + scales.push_back(la.scale); } + + llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data()); } struct llama_model_params common_model_params_to_llama(common_params & params) { @@ -1339,14 +1316,12 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.devices = params.devices.data(); } - if (params.n_gpu_layers != -1) { - mparams.n_gpu_layers = params.n_gpu_layers; - } - + mparams.n_gpu_layers = params.n_gpu_layers; mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; + mparams.use_direct_io = params.use_direct_io; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; @@ -1451,66 +1426,6 @@ void common_batch_add( batch.n_tokens++; } -// -// Token utils -// - -size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, 
set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - // // Vocab utils // @@ -1845,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const { LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); return r; } + +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + LOG_ERR("%s: failed to replay last token\n", __func__); + return false; + } + return true; +} + +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector & tokens, + int & n_past, + int n_batch, + std::string_view state_path, + bool save_state) { + const int n_eval = tokens.size(); + if (n_eval == 0) { + return true; + } + + if (save_state && n_eval > 1) { + const int n_tokens_before_last = n_eval - 1; + + GGML_ASSERT(n_eval <= n_batch); + + // Decode all but the last token so we can save the memory state before decoding the last token. + // This is done so we can restore the session state later and replay the last token. + // Memory implementations in recurrent/hybrid models don't support removing tokens from their + // memory, so we can't just remove the last token from the memory and replay the last token which + // is the reason for this logic. 
+ if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_tokens_before_last))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_tokens_before_last; + + llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last); + LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last); + + llama_token last_token = tokens.back(); + llama_batch batch = llama_batch_get_one(&last_token, 1); + int32_t pos = n_past; + batch.pos = &pos; + + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval last token\n", __func__); + return false; + } + n_past++; + } else { + if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_eval))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + + return true; +} diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index d70744840fb..c5a80375713 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT; extern const char * LLAMA_COMPILER; extern const char * LLAMA_BUILD_TARGET; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + struct common_control_vector_load_info; // @@ -80,6 +82,8 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_BATCHED, + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -117,6 +121,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, + COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12, }; // dimensionality reduction methods, used by cvector-generator @@ -159,37 +164,50 @@ enum common_params_sampling_config : uint64_t { COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, }; +enum common_speculative_type { + COMMON_SPECULATIVE_TYPE_NONE, // no speculative 
decoding + COMMON_SPECULATIVE_TYPE_DRAFT, // draft model + COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values + COMMON_SPECULATIVE_TYPE_NGRAM_MOD, + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache + COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type +}; // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float xtc_probability = 0.00f; // 0.0 = disabled - float xtc_threshold = 0.10f; // > 0.5 disables XTC - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: - float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) - int32_t 
dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty - int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float top_n_sigma = -1.00f;// -1.0 = disabled - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = 
disabled) + float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f; // -1.0 = disabled + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; - bool no_perf = false; // disable performance metrics + bool no_perf = false; // disable performance metrics bool timing_per_token = false; uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers @@ -216,6 +234,8 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool backend_sampling = false; + bool has_logit_bias() const { return !logit_bias.empty(); } @@ -233,17 +253,39 @@ struct common_params_model { std::string name = ""; // in format /[:] (tag is optional) // NOLINT }; +struct common_ngram_mod; + struct common_params_speculative { - std::vector devices; // devices to use for offloading + common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding - int32_t n_ctx = 0; // draft context size - int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding - int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) - std::vector> replacements; // main to speculative model replacements - std::vector tensor_buft_overrides; + // general-purpose speculative decoding parameters + + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + float p_split 
= 0.1f; // speculative decoding split probability + float p_min = 0.75f; // minimum speculative decoding probability (greedy) + + // ngram-based speculative decoding + + uint16_t ngram_size_n = 12; // ngram size for lookup + uint16_t ngram_size_m = 48; // mgram size for speculative tokens + uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + + std::shared_ptr ngram_mod; + + std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT + + // draft-model speculative decoding + + struct common_params_model mparams_dft; + + llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts + + llama_context_params cparams_dft; // these are the parameters for the draft llama_context + + int32_t n_ctx = 0; // draft context size + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V @@ -251,7 +293,14 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - struct common_params_model model; + std::vector devices; // devices to use for offloading + + std::vector> replacements; // main to speculative model replacements + std::vector tensor_buft_overrides; + + bool has_dft() const { + return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); + } }; struct common_params_vocoder { @@ -277,6 +326,7 @@ struct common_params_diffusion { }; // reasoning API response format (not to be confused as chat template's reasoning format) +// only used by server enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content` @@ -329,12 +379,14 @@ struct 
common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - bool fit_params = true; // whether to fit unset model/context parameters to free device memory - size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory - int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // whether to fit unset model/context parameters to free device memory + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + + // margin per device in bytes for fitting parameters to free memory: + std::vector fit_params_target = std::vector(llama_max_devices(), 1024 * 1024*1024); enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs @@ -358,7 +410,8 @@ struct common_params { struct common_params_model model; - std::string model_alias = ""; // model alias // NOLINT + std::set model_alias; // model aliases // NOLINT + std::set model_tags; // model tags (informational, not used for routing) // NOLINT std::string hf_token = ""; // HF token // NOLINT std::string prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT @@ -366,10 +419,13 @@ struct common_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string 
input_suffix = ""; // string to suffix user inputs with // NOLINT - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; @@ -420,7 +476,8 @@ struct common_params { bool kv_unified = false; // enable unified KV cache bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool use_mmap = true; // use mmap for faster loads + bool use_mmap = true; // enable mmap to use filesystem cache + bool use_direct_io = false; // read from disk without buffering bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation @@ -464,6 +521,7 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + bool cache_prompt = true; // whether to enable prompt caching int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
@@ -475,7 +533,8 @@ struct common_params { bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; - bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response + bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response + int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time std::vector api_keys; @@ -484,8 +543,11 @@ struct common_params { std::map default_template_kwargs; + // webui configs + bool webui = true; + std::string webui_config_json; + // "advanced" endpoints are disabled by default for better security - bool webui = true; bool endpoint_slots = true; bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; @@ -552,10 +614,6 @@ struct common_params { // return false from callback to abort model loading or true to continue llama_progress_callback load_progress_callback = NULL; void * load_progress_callback_user_data = NULL; - - bool has_speculative() const { - return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); - } }; // call once at the start of a program if it uses libcommon @@ -613,30 +671,55 @@ static std::vector string_split(const std::string & str, char delim) { } template<> -std::vector string_split(const std::string & input, char separator) +inline std::vector string_split(const std::string & str, char delim) { std::vector parts; size_t begin_pos = 0; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(begin_pos, separator_pos - begin_pos); + size_t delim_pos = str.find(delim); + while (delim_pos != std::string::npos) { + std::string part = str.substr(begin_pos, delim_pos - begin_pos); parts.emplace_back(part); - begin_pos = separator_pos + 1; - separator_pos = input.find(separator, 
begin_pos); + begin_pos = delim_pos + 1; + delim_pos = str.find(delim, begin_pos); } - parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + parts.emplace_back(str.substr(begin_pos)); return parts; } -static bool string_starts_with(const std::string & str, - const std::string & prefix) { // While we wait for C++20's std::string::starts_with... - return str.rfind(prefix, 0) == 0; +// remove when moving to c++20 +inline bool string_starts_with(std::string_view str, std::string_view prefix) { + return str.size() >= prefix.size() && + str.compare(0, prefix.size(), prefix) == 0; +} + +// remove when moving to c++20 +inline bool string_ends_with(std::string_view str, std::string_view suffix) { + return str.size() >= suffix.size() && + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; } -// While we wait for C++20's std::string::ends_with... -bool string_ends_with(const std::string_view & str, const std::string_view & suffix); -bool string_remove_suffix(std::string & str, const std::string_view & suffix); -size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); +inline bool string_remove_suffix(std::string & str, std::string_view suffix) { + if (string_ends_with(str, suffix)) { + str.resize(str.size() - suffix.size()); + return true; + } + return false; +} + +inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) { + if (!str.empty() && !stop.empty()) { + const size_t max_len = std::min(str.size(), stop.size()); + const char last_char = str.back(); + for (size_t len = max_len; len > 0; --len) { + if (stop[len - 1] == last_char) { + if (string_ends_with(str, stop.substr(0, len))) { + return str.size() - len; + } + } + } + } + return std::string::npos; +} bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -685,12 +768,12 @@ struct common_init_result { llama_model * model(); llama_context * 
context(); + common_sampler * sampler(llama_seq_id seq_id); + void reset_samplers(); std::vector & lora(); - void free_context(); - private: struct impl; std::unique_ptr pimpl; @@ -722,15 +805,22 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +// decodes a single batch of tokens for a prompt and manages session tokens // -// Token utils -// - -// longest common prefix -size_t common_lcp(const llama_tokens & a, const llama_tokens & b); - -// longet common subsequence -size_t common_lcs(const llama_tokens & a, const llama_tokens & b); +// Note: We save state before the last token so that we can replay it to ensure +// compatibility with all memory types. Recurrent/hybrid models cannot remove +// tokens from memory, so this approach works across all model architectures. +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector & embd, + int & n_past, + int n_batch, + std::string_view state_path, + bool save_state); + +// replays the last token after loading state to regenerate logits +// used after loading session state to ensure the sampling context has valid logits +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos); // // Vocab utils @@ -823,11 +913,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps"; -static std::string llm_ffn_exps_block_regex(int idx) { +inline std::string llm_ffn_exps_block_regex(int idx) { return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX); } -static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { +inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() }; } diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index 6935d84e226..11a1d483980 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ 
b/llama/llama.cpp/common/sampling.cpp @@ -104,10 +104,9 @@ struct ring_buffer { struct common_sampler { common_params_sampling params; + struct llama_sampler * grmr; struct llama_sampler * chain; - bool grammar; - ring_buffer prev; std::vector cur; @@ -121,17 +120,34 @@ struct common_sampler { } void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (uint32_t i = 0; i < sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } cur_p = { cur.data(), cur.size(), -1, false }; @@ -151,54 +167,59 @@ std::string common_params_sampling::print() const { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = 
%.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, - mirostat, mirostat_eta, mirostat_tau); + mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay); return std::string(result); } -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; + llama_sampler * grmr = nullptr; llama_sampler * chain = llama_sampler_chain_init(lparams); - bool grammar = false; std::vector samplers; if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str())); - grammar = true; + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { std::vector trigger_patterns; - std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { const auto & word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); + 
trigger_patterns.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { - patterns_anywhere.push_back(trigger.value); + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { - trigger_patterns.push_back(trigger.value); + const auto & pattern = trigger.value; + std::string anchored = "^$"; + if (!pattern.empty()) { + anchored = (pattern.front() != '^' ? "^" : "") + + pattern + + (pattern.back() != '$' ? "$" : ""); + } + trigger_patterns.push_back(anchored); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -212,10 +233,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; trigger_patterns_c.reserve(trigger_patterns.size()); for (const auto & regex : trigger_patterns) { @@ -224,15 +241,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co if (!params.grammar.empty()) { if (params.grammar_lazy) { - samplers.push_back( - llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size())); + grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size()); } else { - samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root")); + grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } - - grammar = true; } } @@ -241,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } if (params.mirostat == 0) { + + bool use_adaptive_p = false; // see below + for (const auto & cnstr : params.samplers) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: 
@@ -250,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - - samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - samplers.push_back(llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + 
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - samplers.push_back(llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill(vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: + // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects + // a single token, so we will add `dist` at the end of the chain by default, + // unless the user specifically included `adaptive-p`. we set this flag here + // so we know to add the sampler at the very end. + use_adaptive_p = true; break; default: GGML_ASSERT(false && "unknown sampler type"); } } - - samplers.push_back(llama_sampler_init_dist(params.seed)); + if (use_adaptive_p) { + // only if user explicitly included adaptive-p sampler + samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed)); + } else { + // default: sample from distribution + samplers.push_back(llama_sampler_init_dist(params.seed)); + } } else if (params.mirostat == 1) { samplers.push_back(llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); @@ -301,10 +329,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(chain, smpl); } + if (grmr && params.backend_sampling) { + LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__); + + params.backend_sampling = false; + } + auto * result = new common_sampler { /* .params = */ params, 
+ /* .grmr = */ grmr, /* .chain = */ chain, - /* .grammar = */ grammar, /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, /* .cur_p = */ {}, @@ -314,47 +348,45 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } void common_sampler_free(struct common_sampler * gsmpl) { - if (gsmpl) { - llama_sampler_free(gsmpl->chain); - - delete gsmpl; + if (!gsmpl) { + return; } + + llama_sampler_free(gsmpl->grmr); + llama_sampler_free(gsmpl->chain); + + delete gsmpl; } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { - const auto tm = gsmpl->tm(); - - if (gsmpl->grammar) { - const int n_smpl = llama_sampler_chain_n(gsmpl->chain); + if (!gsmpl) { + return; + } - for (int i = 0; i < n_smpl; i++) { - auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + const auto tm = gsmpl->tm(); - // the grammar sampler is always the first one - if (i == 0) { - if (accept_grammar) { - llama_sampler_accept(smpl, token); - } - } else { - llama_sampler_accept(smpl, token); - } - } - } else { - llama_sampler_accept(gsmpl->chain, token); + if (gsmpl->grmr && accept_grammar) { + llama_sampler_accept(gsmpl->grmr, token); } + llama_sampler_accept(gsmpl->chain, token); + gsmpl->prev.push_back(token); } void common_sampler_reset(struct common_sampler * gsmpl) { + if (!gsmpl) { + return; + } + gsmpl->reset(); } struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { return new common_sampler { /* .params = */ gsmpl->params, + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* .grammar = */ gsmpl->grammar, /* .prev = */ gsmpl->prev, /* .cur = */ gsmpl->cur, /* .cur_p = */ gsmpl->cur_p, @@ -407,10 +439,14 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + if (!gsmpl) { + return nullptr; + } + return gsmpl->chain; } 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { llama_synchronize(ctx); // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations @@ -418,11 +454,61 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_token id = LLAMA_TOKEN_NULL; + auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + { + id = llama_get_sampled_token_ith(ctx, idx); + + if (id != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id); + + GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported"); + + // TODO: simplify + gsmpl->cur.resize(1); + gsmpl->cur[0] = { id, 0.0f, 1.0f }; + cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true }; + + return id; + } + } + + gsmpl->set_logits(ctx, idx); + + if (grammar_first) { + llama_sampler_apply(grmr, &cur_p); + } + + llama_sampler_apply(chain, &cur_p); + + id = cur_p.data[cur_p.selected].id; + + if (grammar_first) { + return id; + } + + // check if it the sampled token fits the grammar (grammar-based rejection sampling) + { + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; + + llama_sampler_apply(grmr, &single_token_data_array); + + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + } + + // resampling: + // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain gsmpl->set_logits(ctx, idx); 
+ llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); @@ -432,7 +518,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return id; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -440,7 +526,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); common_sampler_accept(gsmpl, id, true); @@ -452,7 +538,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample } if (i == draft.size()) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); common_sampler_accept(gsmpl, id, true); @@ -462,13 +548,13 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); + return 
common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { @@ -553,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a'; default : return '?'; } } @@ -569,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p"; default : return ""; } } @@ -585,6 +673,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; // since samplers names are written multiple ways @@ -600,6 +689,7 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; @@ -636,6 +726,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index ace5d3d020b..5b57ad65811 100644 --- a/llama/llama.cpp/common/sampling.h 
+++ b/llama/llama.cpp/common/sampling.h @@ -36,7 +36,8 @@ struct common_sampler; // llama_sampler API overloads -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); +// note: can mutate params in some cases +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params); void common_sampler_free(struct common_sampler * gsmpl); @@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +// get the underlying llama_sampler_chain struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); // extended sampling implementation: @@ -57,7 +59,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); +// if grammar_first is true, the grammar is applied before the samplers (slower) +// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar +// +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // generalized version of common_sampler_sample // @@ -75,10 +80,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context 
* ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/llama/llama.cpp/common/unicode.cpp b/llama/llama.cpp/common/unicode.cpp new file mode 100644 index 00000000000..56ab0f468e0 --- /dev/null +++ b/llama/llama.cpp/common/unicode.cpp @@ -0,0 +1,64 @@ +#include "unicode.h" + +// implementation adopted from src/unicode.cpp + +size_t utf8_sequence_length(unsigned char first_byte) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(first_byte) >> 4; + return lookup[highbits]; +} + +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { + if (offset >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + + // ASCII fast path + if (!(input[offset] & 0x80)) { + return utf8_parse_result(utf8_parse_result::SUCCESS, input[offset], 1); + } + + // Invalid: continuation byte as first byte + if (!(input[offset] & 0x40)) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + + // 2-byte sequence + if (!(input[offset] & 0x20)) { + if (offset + 1 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x1f) << 6) | (input[offset + 1] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 2); + } + + // 3-byte sequence + if (!(input[offset] & 0x10)) { + if (offset + 2 >= input.size()) { + return 
utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x0f) << 12) | ((input[offset + 1] & 0x3f) << 6) | (input[offset + 2] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 3); + } + + // 4-byte sequence + if (!(input[offset] & 0x08)) { + if (offset + 3 >= input.size()) { + return utf8_parse_result(utf8_parse_result::INCOMPLETE); + } + if ((input[offset + 1] & 0xc0) != 0x80 || (input[offset + 2] & 0xc0) != 0x80 || (input[offset + 3] & 0xc0) != 0x80) { + return utf8_parse_result(utf8_parse_result::INVALID); + } + auto result = ((input[offset] & 0x07) << 18) | ((input[offset + 1] & 0x3f) << 12) | ((input[offset + 2] & 0x3f) << 6) | (input[offset + 3] & 0x3f); + return utf8_parse_result(utf8_parse_result::SUCCESS, result, 4); + } + + // Invalid first byte + return utf8_parse_result(utf8_parse_result::INVALID); +} diff --git a/llama/llama.cpp/common/unicode.h b/llama/llama.cpp/common/unicode.h new file mode 100644 index 00000000000..9d9e8e1227a --- /dev/null +++ b/llama/llama.cpp/common/unicode.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +// UTF-8 parsing utilities for streaming-aware unicode support + +struct utf8_parse_result { + uint32_t codepoint; // Decoded codepoint (only valid if status == SUCCESS) + size_t bytes_consumed; // How many bytes this codepoint uses (1-4) + enum status { SUCCESS, INCOMPLETE, INVALID } status; + + utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0) + : codepoint(cp), bytes_consumed(bytes), status(s) {} +}; + +// Determine the expected length of a UTF-8 sequence from its first byte +// Returns 0 for invalid first bytes +size_t utf8_sequence_length(unsigned char first_byte); + +// Parse a single UTF-8 codepoint from input +utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); diff --git 
a/llama/llama.cpp/include/llama-cpp.h b/llama/llama.cpp/include/llama-cpp.h index 8f6368177de..807e77f6280 100644 --- a/llama/llama.cpp/include/llama-cpp.h +++ b/llama/llama.cpp/include/llama-cpp.h @@ -21,7 +21,9 @@ struct llama_sampler_deleter { }; struct llama_adapter_lora_deleter { - void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } + void operator()(llama_adapter_lora *) { + // llama_adapter_lora_free is deprecated + } }; typedef std::unique_ptr llama_model_ptr; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index f8629300991..077f66dc651 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -286,7 +286,7 @@ extern "C" { // NULL-terminated list of buffer types to use for tensors that match a pattern const struct llama_model_tensor_buft_override * tensor_buft_overrides; - int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers enum llama_split_mode split_mode; // how to split the model across multiple GPUs // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE @@ -309,6 +309,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. 
bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) @@ -316,6 +317,11 @@ extern "C" { bool no_alloc; // only load metadata and simulate memory allocations }; + struct llama_sampler_seq_config { + llama_seq_id seq_id; + struct llama_sampler * sampler; + }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { @@ -364,6 +370,12 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + + // [EXPERIMENTAL] + // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) + // note: the samplers must be sampler chains (i.e. 
use llama_sampler_chain_init) + struct llama_sampler_seq_config * samplers; + size_t n_samplers; }; // model quantization parameters @@ -377,6 +389,7 @@ extern "C" { bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types @@ -467,16 +480,24 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); + enum llama_params_fit_status { + LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. 
because no model could be found at the specified path + }; + // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // returns true if the parameters could be successfully modified to fit device memory - // this function is NOT thread safe because it modifies the global llama logger state - LLAMA_API bool llama_params_fit( + // - returns true if the parameters could be successfully modified to fit device memory + // - this function is NOT thread safe because it modifies the global llama logger state + // - only parameters that have the same value as in llama_default_model_params are modified + // with the exception of the context size which is modified if and only if equal to 0 + LLAMA_API enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t margin, // margin of memory to leave per device in bytes + size_t * margins, // margins of memory to leave per device in bytes uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log @@ -517,6 +538,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const 
struct llama_model * model); @@ -600,6 +622,8 @@ extern "C" { // // Load a LoRA adapter from file + // The adapter is valid as long as the associated model is not freed + // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -624,7 +648,8 @@ extern "C" { // Manually free a LoRA adapter // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), + "adapters are now freed together with the associated model"); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -632,21 +657,12 @@ extern "C" { // The following functions operate on a llama_context, hence the naming: llama_verb_... - // Add a loaded LoRA adapter to given context - // This will not modify model's weight - LLAMA_API int32_t llama_set_adapter_lora( - struct llama_context * ctx, - struct llama_adapter_lora * adapter, - float scale); - - // Remove a specific LoRA adapter from given context - // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_rm_adapter_lora( + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. + LLAMA_API int32_t llama_set_adapters_lora( struct llama_context * ctx, - struct llama_adapter_lora * adapter); - - // Remove all LoRA adapters from given context - LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); + struct llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. 
@@ -654,7 +670,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_apply_adapter_cvec( + LLAMA_API int32_t llama_set_adapter_cvec( struct llama_context * ctx, const float * data, size_t len, @@ -983,6 +999,32 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // backend sampling API [EXPERIMENTAL] + // note: use only if the llama_context was created with at least one llama_sampler_seq_config + // + + // Get the backend sampled token for the ith token. + // Returns LLAMA_TOKEN_NULL if no token was sampled. + LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled probabilites for the ith token + // The index matches llama_get_sampled_token_ith(). + // Returns NULL if no probabilites were generated. + LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled logits for the ith token + // Returns NULL if no logits were sampled. + LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled candidates (token ids) for the ith token + // These are needed to map probability/logit indices to vocab token ids. + // Returns NULL if no candidates were sampled. 
+ LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // // Vocab // @@ -1100,9 +1142,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. @@ -1154,11 +1196,16 @@ extern "C" { // // llama_sampler_free(smpl); // - // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). 
- // typedef void * llama_sampler_context_t; + struct llama_sampler_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL @@ -1168,17 +1215,44 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_sampler * smpl, ...); + // [EXPERIMENTAL] + // backend sampling interface: + + // return true if the backend supports all ops needed by the sampler + // note: call once per sampler + bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); + + // call after .backend_apply() + void (*backend_accept)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); + + // call after .backend_init() + void (*backend_apply)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data); + + // called before graph execution to set inputs for the current ubatch + void (*backend_set_input)(struct llama_sampler * smpl); }; struct llama_sampler { - const struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + struct llama_sampler_i * iface; + + llama_sampler_context_t ctx; }; + // [EXPERIMENTAL] + // attach a sampler to the context + // note: prefer initializing the context with llama_context_params.samplers when possible + LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); + // mirror of llama_sampler_i: - LLAMA_API struct 
llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1194,7 +1268,15 @@ extern "C" { // important: takes ownership of the sampler object and will free it when llama_sampler_free is called LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + + // return NULL if: + // - the sampler is NULL + // - the sampler is not a llama_sampler_chain + // - the index is out of bounds, unless i == -1 + // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain) + LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i); + + // the total number of samplers in the chain LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed @@ -1203,7 +1285,9 @@ extern "C" { // available samplers: LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); - LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + + /// seed == LLAMA_DEFAULT_SEED to use a random seed. 
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Setting k <= 0 makes this a noop @@ -1304,6 +1388,33 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); + /// adaptive-p: select tokens near a configurable target probability over time. + /// + /// the adaptive-p sampler transforms the token probability distribution to favor tokens + /// that fall near a user-configurable probability target. + /// + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an + /// adapted target probability at each sampling step, thus maintaining the desired target + /// probability over time. + /// + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last + /// in the sampler chain (like mirostat, dist, greedy). + /// + /// only mild truncation before this sampler is recommended. we suggest applying min-p + /// before adaptive-p as the only other active sampler in the chain. + /// + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + /// @param seed RNG seed + /// + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 + /// + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, @@ -1357,12 +1468,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. 
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index d8eef75a7ad..d6a5800e63a 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -411,6 +411,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } } + // register adapter with model + model.loras.insert(&adapter); + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } @@ -468,8 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora * adapter) { - delete adapter; +void llama_adapter_lora_free(llama_adapter_lora *) { + // deprecated: adapters are freed by llama_model's destructor } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct 
llama_adapter_lora * adapter) { diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 4f65247c0fe..aa3ab63ad75 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -39,6 +39,8 @@ struct llama_adapter_cvec { std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -77,6 +79,11 @@ struct llama_adapter_lora { ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); + + uint32_t get_n_nodes() const { + return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat + } }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index 2ce8ffec022..977783cbe38 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -20,11 +20,13 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, + { LLM_ARCH_EUROBERT, "eurobert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -36,11 +38,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO2, "plamo2" }, + { LLM_ARCH_PLAMO3, "plamo3" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, 
"internlm2" }, @@ -70,15 +75,18 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_EXAONE_MOE, "exaone-moe" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -115,6 +123,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_PADDLEOCR, "paddleocr" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, + { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, + { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -148,6 +162,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -155,6 +170,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, 
"%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -185,6 +202,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -205,6 +223,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -212,21 +231,25 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { 
LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -239,6 +262,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, 
"%s.posnet.embedding_length" }, @@ -326,6 +351,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -348,12 +374,14 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -365,6 +393,15 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -491,6 +528,10 @@ static const std::map 
LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -500,6 +541,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_DECI: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, @@ -703,6 +745,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_INTERNLM2: case LLM_ARCH_GRANITE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_PADDLEOCR: case LLM_ARCH_SMOLLM3: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -781,6 +824,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_EUROBERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; + case LLM_ARCH_MODERN_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, + }; case LLM_ARCH_JINA_BERT_V2: return { LLM_TENSOR_TOKEN_EMBD, @@ -930,11 +1002,13 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_V, LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, 
LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -947,6 +1021,64 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN35MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -1060,6 +1192,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_PLAMO3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + 
LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; case LLM_ARCH_CODESHELL: return { LLM_TENSOR_TOKEN_EMBD, @@ -1459,6 +1607,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1511,6 +1660,12 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_GLM4_MOE: return { @@ -1543,6 +1698,46 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + 
LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, @@ -1621,6 +1816,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, }; + case LLM_ARCH_JAIS2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_NEMOTRON_H: return { LLM_TENSOR_TOKEN_EMBD, @@ -1690,6 +1899,38 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_EXAONE_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_RWKV6: return { LLM_TENSOR_TOKEN_EMBD, @@ -2040,6 +2281,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM_LFM2, LLM_TENSOR_OUTPUT, + LLM_TENSOR_DENSE_2_OUT, }; case LLM_ARCH_LFM2MOE: return { @@ -2058,7 +2300,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, LLM_TENSOR_TOKEN_EMBD, 
- LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT_NORM_LFM2, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, @@ -2174,27 +2416,142 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, }; + case LLM_ARCH_MIMO2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { LLM_TENSOR_TOKEN_EMBD, }; - case LLM_ARCH_SOLAR: + case LLM_ARCH_MAINCODER: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_OUTPUT, LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_V, LLM_TENSOR_ATTN_OUT, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, + }; + case LLM_ARCH_SOLAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + 
LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, LLM_TENSOR_BSKCN_TV, }; + case LLM_ARCH_KIMI_LINEAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + // Dense FFN (layer 0 only) + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (layers 1+) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // Shared experts + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + LLM_TENSOR_SSM_CONV1D_Q, + LLM_TENSOR_SSM_CONV1D_K, + LLM_TENSOR_SSM_CONV1D_V, + LLM_TENSOR_SSM_F_A, + LLM_TENSOR_SSM_F_B, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_G_A, + LLM_TENSOR_SSM_G_B, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_NORM, + // MLA + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_KV_A_NORM, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } @@ -2218,6 +2575,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output 
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2270,6 +2628,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2298,6 +2657,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2340,6 +2708,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, 
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, @@ -2388,6 +2757,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2480,6 +2853,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 14d461c763f..e9f2739ac12 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -24,11 +24,13 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V3, + LLM_ARCH_EUROBERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ 
-40,11 +42,14 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_PLAMO2, + LLM_ARCH_PLAMO3, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -74,15 +79,18 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, + LLM_ARCH_EXAONE_MOE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -119,6 +127,12 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_PADDLEOCR, + LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, + LLM_ARCH_LLAMA_EMBED, + LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -152,6 +166,7 @@ enum llm_kv { LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, + LLM_KV_EMBEDDING_LENGTH_OUT, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -159,6 +174,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -189,6 +206,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -209,6 +227,7 @@ enum llm_kv { LLM_KV_ATTENTION_GATE_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -216,10 +235,14 @@ enum llm_kv { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_KEY_LENGTH_MLA, 
LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, @@ -243,6 +266,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -350,6 +375,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -391,6 +417,16 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, @@ -467,6 +503,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_BSKCN_TV, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, @@ -492,6 +529,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git 
a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index fc6a6223cfe..c415a998f33 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -57,6 +57,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, + { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -74,6 +75,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, + { "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -136,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("[gMASK]")) { return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + if (tmpl_contains("<|tool_declare|>")) { + return LLM_CHAT_TEMPLATE_EXAONE_MOE; + } return tmpl_contains("") ? 
LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { return LLM_CHAT_TEMPLATE_GLMEDGE; @@ -216,6 +221,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GROK_2; } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; + } else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) { + return LLM_CHAT_TEMPLATE_SOLAR_OPEN; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -226,7 +233,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template @@ -573,6 +580,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "user") { + ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "assistant") { + ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "tool") { + ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n"; + } + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { @@ -845,6 +868,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[unused9]助手:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) { + for (auto message : chat) { + std::string 
role(message->role); + ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>"; + } + if (add_ass) { + ss << "<|begin|>assistant"; + } } else { // template not supported return -1; diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index 684efb4d67f..9ed1db128ec 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -36,6 +36,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_4, + LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, @@ -54,6 +55,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_PANGU_EMBED, + LLM_CHAT_TEMPLATE_SOLAR_OPEN, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 9e699827276..964bb3220c1 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -22,6 +22,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -60,6 +62,25 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + // Initialize backend samplers here so they are part of the sampling graph + // before the reserve passes run later in this function. This avoids a later + // re-reserve when graph nodes change. 
+ if (params.samplers != nullptr && params.n_samplers > 0) { + for (size_t i = 0; i < params.n_samplers; ++i) { + const auto & config = params.samplers[i]; + + if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { + throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); + } + + if (set_sampler(config.seq_id, config.sampler)) { + const int n_samplers = llama_sampler_chain_n(config.sampler); + + LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); + } + } + } + + auto rope_scaling_type = params.rope_scaling_type; + if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { + rope_scaling_type = hparams.rope_scaling_type_train; @@ -127,6 +148,7 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -136,6 +158,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // initialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ?
(atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -230,7 +255,6 @@ llama_context::llama_context( // graph outputs buffer { - // resized during inference when a batch uses more outputs if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -280,22 +304,12 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.params.n_gpu_layers > (int) model.hparams.n_layer && - model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && + model.n_gpu_layers() > model.hparams.n_layer && + model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -305,6 +319,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? 
continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -318,168 +333,218 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + if (cparams.pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); + sched_reserve(); + + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } } + } - cross.v_embd.clear(); + // Initialize the full vocabulary token ids for backend samplers. 
+ { + const int n_vocab = model.vocab.n_tokens(); - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; + } + } +} - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); +llama_context::~llama_context() { + if (!model.hparams.no_alloc) { + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + const size_t size_exp = backend_buf_exp_size[i]; + const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size_exp == size_act) { + LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } else { + LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); } + } + } + ggml_opt_free(opt_ctx); +} - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { - continue; - } - ggml_backend_dev_t device_fa = ggml_backend_get_device( +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } + + sched_need_reserve = false; + + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); + + 
synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + // avoid reserving graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( ggml_backend_sched_get_tensor_backend(sched.get(), n)); - // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - 
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " "is assigned to device %s (usually due to missing support)\n", __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; - break; - } - } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } - } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; } } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; + cparams.auto_fa = false; + } - int n_splits_tg = -1; - int n_nodes_tg = -1; + // reserve worst-case graph + int n_splits_pp = -1; + int n_nodes_pp = -1; - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + int n_splits_tg = -1; + 
int n_nodes_tg = -1; + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); - } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } + throw std::runtime_error("failed to allocate compute pp buffers"); } - - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); } - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate 
compute tg buffers"); } - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); - } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); - } + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); } + } - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); } - - if (n_splits_pp == 
n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); } } -} -llama_context::~llama_context() { - // FIXME this currently results in a use-after-free bug if the model is freed before the context - // if (!model.hparams.no_alloc) { - // for (size_t i = 0; i < backend_ptrs.size(); ++i) { - // ggml_backend_t backend = backend_ptrs[i]; - // ggml_backend_buffer_type_t buft = backend_buft[i]; - - // const size_t size_exp = backend_buf_exp_size[i]; - // const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); - // if (size_exp == size_act) { - // LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } else { - // LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } - // } - // } - ggml_opt_free(opt_ctx); + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", + __func__, (t_end_us - 
t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); } void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -614,39 +679,48 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } -float * llama_context::get_logits_ith(int32_t i) { +int64_t llama_context::output_resolve_row(int32_t i) const { int64_t j = -1; + // support negative indices (last output row) + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + // use output_ids to translate the batch token index into a row number + // that holds this token's data. 
+ j = output_ids[i]; + } + + if (j < 0) { + // the batch token was not configured to output anything + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + + if (j >= n_outputs) { + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return j; +} + +float * llama_context::get_logits_ith(int32_t i) { output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - return logits + j*model.vocab.n_tokens(); + const int64_t j = output_resolve_row(i); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -660,39 +734,24 @@ float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } -float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; +llama_token * llama_context::get_sampled_tokens() const{ + return sampling.sampled.data; +} +float * llama_context::get_embeddings_ith(int32_t i) { output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else 
if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - return embd + j*model.hparams.n_embd; + const int64_t j = output_resolve_row(i); + const uint32_t n_embd_out = model.hparams.n_embd_out(); + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -712,6 +771,137 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +llama_token llama_context::get_sampled_token_ith(int32_t idx) { + output_reorder(); + + if (!sampling.sampled.has_data()) { + return LLAMA_TOKEN_NULL; + } + + try { + const int64_t row = output_resolve_row(idx); + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); + return LLAMA_TOKEN_NULL; + } +} + +float * llama_context::get_sampled_probs_ith(int32_t idx) { + output_reorder(); + + if (!sampling.probs.has_data()) { + return nullptr; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { + return nullptr; + } + return sampling.probs.data + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +float * llama_context::get_sampled_logits_ith(int32_t idx) { + output_reorder(); + + if (!sampling.logits.has_data()) { + return nullptr; 
+ } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { + return nullptr; + } + return sampling.logits.data + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { + output_reorder(); + + try { + const int64_t row = output_resolve_row(idx); + if (sampling.candidates.has_data() && + (size_t) row < sampling.candidates_count.size() && + sampling.candidates_count[row] > 0) { + return sampling.candidates.data + row*model.vocab.n_tokens(); + } + } catch (const std::exception & err) { + // fallback to full vocab list + GGML_UNUSED(err); + } + + return sampling.token_ids_full_vocab.data(); +} + +size_t llama_context::get_sampled_candidates_count(int32_t idx) { + output_reorder(); + + if (!sampling.candidates.has_data()) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.candidates_count.size()) { + return 0; + } + return sampling.candidates_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_logits_count(int32_t idx) { + output_reorder(); + + if (!sampling.logits.has_data()) { + return model.vocab.n_tokens(); + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size()) { + return 0; + } + return sampling.logits_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_probs_count(int32_t idx) { + output_reorder(); + + if 
(!sampling.probs.has_data()) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size()) { + return 0; + } + return sampling.probs_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -754,48 +944,117 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + // warmups are usually with small batches, so no need to reserve + //sched_need_reserve = true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); - loras[adapter] = scale; -} + const bool can_offload = + sampler && + sampler->iface->backend_init && + sampler->iface->backend_apply && + llama_sampler_chain_n(sampler) > 0; + + if (sampler && can_offload) { + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); + + sampler->iface->backend_init(sampler, buft); -bool 
llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); + sampling.samplers[seq_id] = sampler; + + sched_need_reserve = true; - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); return true; } - return false; + if (sampler && !can_offload) { + LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + + sampling.samplers.erase(seq_id); + + return false; + } + + sampling.samplers.erase(seq_id); + + sched_need_reserve = true; + + return true; } -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); + + if (adapters_lora_are_same(adapters, n_adapters, scales)) { + return; + } + + loras.reset(new llama_adapter_loras()); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras->insert({adapters[i], scales[i]}); + } + } + + sched_need_reserve = true; +} + +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - loras.clear(); + if (n_adapters != loras->size()) { + return false; + } + + for (size_t i = 0; i < n_adapters; i ++) { + auto it = loras->find(adapters[i]); + + if (it == loras->end() || it->second != scales[i]) { + return false; + } + } + + return true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -803,7 +1062,9 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - 
return cvec.apply(model, data, len, n_embd, il_start, il_end); + // TODO: should we reserve? + + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -905,6 +1166,8 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer @@ -944,16 +1207,16 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -961,10 +1224,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); 
} break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1012,7 +1276,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1032,6 +1296,128 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::map seq_to_row; + // how many output tokens we have seen so far for this ubatch. + uint32_t local = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + // skip tokens that are not output. + if (!ubatch.output[i]) { + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + // row_offset is the number of output tokens before this ubatch. + seq_to_row[seq_id] = row_offset + local; + ++local; + } + return seq_to_row; +} + +static void copy_tensor_async_ints( + const std::map & tensor_map, + const buffer_view & sampled, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (!sampled.has_data()) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < sampled.size); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); + } +} + +static void copy_tensor_async_floats( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (!dst.has_data()) { + 
return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst.data + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } +} + +static void copy_tensor_async_candidates( + const std::map & tensor_map, + const buffer_view & dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (!dst.has_data()) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst.data + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } +} + +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers) { + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + // Check if the output token has at least one sequence without a backend sampler. 
+ for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = ubatch.seq_id[i][j]; + if (samplers.find(seq_id) == samplers.end()) { + return true; + } + } + } + return false; // all sequences use backend sampling +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1052,8 +1438,35 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_embd = hparams.n_embd_inp(); const bool output_all = false; + const bool has_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { + const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; + + // TODO: avoid this workaround in the future + if (has_samplers && batch_inp.logits) { + std::vector seq_output_count(n_seq_max, 0); + + for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { + if (batch_inp.logits[i] == 0) { + continue; + } + + const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; + + for (int32_t s = 0; s < ns; ++s) { + const llama_seq_id seq_id = batch_inp.seq_id ? 
batch_inp.seq_id[i][s] : 0; + + seq_output_count[seq_id]++; + if (seq_output_count[seq_id] > 1) { + LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", + __func__, seq_id, seq_output_count[seq_id]); + return -1; + } + } + } + } + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -1083,6 +1496,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies @@ -1207,22 +1622,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (t_logits && n_outputs > 0) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1230,13 +1645,14 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; 
+ GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1276,6 +1692,19 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); + const auto stride = n_vocab; + + // async copy the sampling data from the backend to the host + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); + + copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); + copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); + } + n_outputs_prev += n_outputs; } while (mctx->next()); @@ -1345,9 +1774,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); - const auto n_batch = cparams.n_batch; - const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto 
n_embd_out = hparams.n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1358,8 +1787,19 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; + + size_t backend_float_count = 0; + size_t backend_token_count = 0; + + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + + // Allocate backend sampling output buffers if there are backend samplers configured. + const bool has_sampling = !sampling.samplers.empty(); + if (has_sampling) { + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates + } if (output_ids.empty()) { // init, never resized afterwards @@ -1367,7 +1807,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); + const size_t new_size = + (logits.size + embd.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1375,12 +1817,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { if (buf_output) { #ifndef NDEBUG // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) - LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif synchronize(); + + // TODO: not needed? 
buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1399,8 +1843,50 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + + logits = has_logits ? buffer_view{output_base, logits.size} : buffer_view{nullptr, 0}; + offset += logits.size * sizeof(float); + + embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; + offset += embd.size * sizeof(float); + + if (has_sampling) { + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); + + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); + + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); + + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); + + // The count vectors keep track of the actual number of logits/probs/candidates + // copied from the backend for each output row. 
+ + sampling.logits_count.resize(n_outputs_max); + sampling.probs_count.resize(n_outputs_max); + sampling.candidates_count.resize(n_outputs_max); + + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); + + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); + } // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1418,17 +1904,44 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } + + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); + } + + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); + } + + for 
(uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); + } + + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); + std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); + } } output_swaps.clear(); @@ -1439,10 +1952,14 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } - return std::max(1024u, 8u*model.n_tensors()); + uint32_t res = std::max(1024u, 8u*model.n_tensors()); + for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } + return res; } llm_graph_result * llama_context::get_gf_res_reserve() const { @@ -1456,7 +1973,7 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::min(n_outputs, n_tokens); + n_outputs = std::max(n_outputs, n_tokens); LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -1475,6 +1992,15 @@ ggml_cgraph * llama_context::graph_reserve( llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); + // set one output token per sequence in order to activate all backend samplers + std::vector seq_ids(n_seqs); + for (uint32_t i = 0; i < n_seqs; ++i) { + seq_ids[i] = i; + ubatch.n_seq_id[i] = 1; + ubatch.seq_id[i] = &seq_ids[i]; + ubatch.output[i] = true; + } + auto * res = 
gf_res_reserve.get(); const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); @@ -1505,7 +2031,7 @@ llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + llm_graph_type gtype) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1514,10 +2040,11 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -1561,16 +2088,9 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -1919,60 +2439,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector w_output_pos; - - w_output_pos.resize(n_outputs); - - 
// build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd, embd_size * sizeof(float)); - } - } - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -1998,67 +2464,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > output_reserve(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - 
this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd_size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); - } - } - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2191,6 +2596,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { @@ -2282,7 +2688,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); @@ -2392,6 +2798,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2551,7 +2959,15 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, 
int32_t i) { ctx->synchronize(); - return ctx->get_logits_ith(i); + float * res = nullptr; + + res = ctx->get_sampled_logits_ith(i); + + if (!res) { + res = ctx->get_logits_ith(i); + } + + return res; } float * llama_get_embeddings(llama_context * ctx) { @@ -2572,37 +2988,76 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -// llama adapter API +bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { + return ctx->set_sampler(seq_id, smpl); +} -int32_t llama_set_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); +llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); - return 0; + return ctx->get_sampled_token_ith(i); } -int32_t llama_rm_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); +float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); - return res ? 
0 : -1; + return ctx->get_sampled_probs_ith(i); +} + +float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_logits_ith(i); } -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); +llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_sampled_candidates_ith(i)); +} + +uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_candidates_count(i)); +} + +uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_logits_count(i)); +} + +uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_probs_count(i)); +} + +// llama adapter API + +int32_t llama_set_adapters_lora( + llama_context * ctx, + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } + + ctx->set_adapters_lora(adapters, n_adapters, scales); + + return 0; } -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 
0 : -1; } diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index c31101330e2..e0d0085c1c3 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -4,6 +4,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -40,6 +41,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. + void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -70,6 +79,18 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_sampled_tokens() const; + llama_token get_sampled_token_ith(int32_t idx); + + float * get_sampled_logits_ith(int32_t idx); + size_t get_sampled_logits_count(int32_t idx); + + float * get_sampled_probs_ith(int32_t idx); + size_t get_sampled_probs_count(int32_t idx); + + const llama_token * get_sampled_candidates_ith(int32_t idx); + size_t get_sampled_candidates_count(int32_t idx); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -84,16 +105,11 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); - - bool rm_adapter_lora( - llama_adapter_lora * adapter); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -196,6 +212,9 @@ struct llama_context { void output_reorder(); + // map the output row index `i` to batch index + 
int64_t output_resolve_row(int32_t i) const; + // // graph // @@ -213,6 +232,8 @@ struct llama_context { ggml_cgraph * graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -235,22 +256,40 @@ struct llama_context { const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; + buffer_view embd = {nullptr, 0}; + + struct sampling_info { + // !samplers.empty() to check if any samplers are active + std::map samplers; + + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; + + std::vector logits_count; + std::vector probs_count; + std::vector candidates_count; + + // optimization + std::vector token_ids_full_vocab; + }; + + sampling_info sampling; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -272,6 +311,8 @@ struct llama_context { ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector backends; diff --git a/llama/llama.cpp/src/llama-cparams.h 
b/llama/llama.cpp/src/llama-cparams.h index fcef8fa9760..2da3bbd6f94 100644 --- a/llama/llama.cpp/src/llama-cparams.h +++ b/llama/llama.cpp/src/llama-cparams.h @@ -30,10 +30,12 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; bool no_perf; bool warmup; bool op_offload; bool kv_unified; + bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index a0299d18162..9d3b896a658 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -2,7 +2,7 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include #include @@ -369,6 +369,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -1321,21 +1359,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; - std::smatch match; for (const auto & trigger_pattern : 
grammar.trigger_patterns) { - if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { grammar.awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; - } - } - if (start == std::string::npos) { - start = match.position(0); - } // replay tokens that overlap with [start, end) for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index 5c0da404938..57847583ada 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -130,6 +130,8 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; struct llama_grammar { diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 1d0d7197e1f..23a86ea2905 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -7,11 +7,50 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include #include #include +#include +#include +#include + +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { @@ -21,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -31,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -62,7 +102,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { bool res = true; - res &= pos->ne[0] == params.ubatch.n_tokens; + res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd; return res; } @@ -95,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { int32_t * data = 
(int32_t *) pos_bucket->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); - } + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); } } } @@ -147,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -322,34 +363,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - const llama_pos p1 = ubatch->pos[i1]; + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; + const uint64_t idst = i1*n_kv; - for (int i0 = 0; i0 < n_tokens; ++i0) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - const llama_pos p0 = ubatch->pos[i0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - // mask different sequences - if (s0 != s1) { - continue; - } - - // mask future tokens - if (cparams.causal_attn && 
p0 > p1) { - continue; - } + // mask different sequences + if (s0 != s1) { + continue; + } - // apply SWA if any - if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { - continue; - } + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } }; @@ -402,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); + + return res; +} + +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -433,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= 
self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -453,27 +508,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { float * data = (float *) cross_kq_mask->data; - for (int h = 0; h < 1; ++h) { - for (int i = 0; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int i = 0; i < n_tokens; ++i) { + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; - } + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; } - - data[h*(n_enc*n_tokens) + i*n_enc + j] = f; } - } - for (int i = n_tokens; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; - } + data[i*n_enc + j] = f; } } } @@ -507,8 +554,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -521,6 +567,154 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +// TODO: Hybrid input classes are a bit redundant. 
+// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not 
be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + } + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= 
inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { + // set the inputs only for the active samplers in the current ubatch + std::unordered_set active_samplers; + for (uint32_t i = 0; i < ubatch->n_tokens; i++) { + if (ubatch->output[i]) { + llama_seq_id seq_id = ubatch->seq_id[i][0]; + active_samplers.insert(seq_id); + } + } + + for (auto seq_id : active_samplers) { + if (samplers.find(seq_id) == samplers.end()) { + continue; + } + + auto & sampler = samplers[seq_id]; + + if (sampler->iface->backend_set_input) { + sampler->iface->backend_set_input(sampler); + } + } +} + +bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { + if (samplers.size() != params.samplers.size()) { + return false; + } + + for (const auto & [seq_id, sampler] : params.samplers) { + if (samplers[seq_id] != sampler) { + return false; + } + } + + return true; +} + // // llm_graph_result // @@ -537,10 +731,15 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_sampled.clear(); + t_sampled_probs.clear(); + t_sampled_logits.clear(); + t_candidates.clear(); params = {}; @@ -565,6 +764,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } +void llm_graph_result::set_outputs() { + if (t_logits != nullptr) { + ggml_set_output(t_logits); + } + if (t_embd != nullptr) { + ggml_set_output(t_embd); + } + if (t_embd_pooled != nullptr) { + ggml_set_output(t_embd_pooled); + } + for (auto & [seq_id, t] : t_sampled) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : 
t_sampled_probs) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_logits) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_candidates) { + if (t != nullptr) { + ggml_set_output(t); + } + } +} + bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { if (debug > 1) { @@ -646,6 +877,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), @@ -813,6 +1045,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -876,8 +1128,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -913,7 +1165,8 @@ 
ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -929,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( w_scale, gating_op, il, - probs_in + probs_in, + gate_up_exps ); } @@ -952,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * gate_up_exps_b) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1091,30 +1347,73 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_weighted", il); } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); + ggml_tensor * up = nullptr; + ggml_tensor * experts = nullptr; - if (up_exps_b) { - up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); - cb(up, "ffn_moe_up_biased", il); - } + if (gate_up_exps) { + // merged gate_up path: one mul_mat_id, then split into gate and up views + ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens] + cb(gate_up, "ffn_moe_gate_up", il); - ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + if (gate_up_exps_b) { + gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts); + cb(gate_up, "ffn_moe_gate_up_biased", il); + } + + const int64_t n_ff = gate_up->ne[0] / 2; + cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, 
"ffn_moe_gate", il); + up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]); + cb(up, "ffn_moe_up", il); } else { - cur = up; - } + // separate gate and up path + up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } - if (gate_exps_b) { - cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); - cb(cur, "ffn_moe_gate_biased", il); + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } } + const bool has_gate = gate_exps || gate_up_exps; + switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + } + + if (has_gate) { cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1122,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - if (gate_exps) { + if (has_gate) { cur = ggml_geglu_split(ctx0, cur, up); cb(cur, "ffn_moe_geglu", il); } else { @@ -1138,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, 
"ffn_moe_swiglu_oai", il); } break; case LLM_FFN_RELU: - if (gate_exps) { + if (has_gate) { cur = ggml_reglu_split(ctx0, cur, up); cb(cur, "ffn_moe_reglu", il); } else { @@ -1146,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_relu", il); } break; case LLM_FFN_RELU_SQR: - if (gate_exps) { + if (has_gate) { // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { @@ -1204,17 +1503,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; + + assert(n_embd_inp >= n_embd); - auto inp = std::make_unique(); + auto inp = std::make_unique(n_embd_inp); - ggml_tensor * cur = nullptr; + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); + + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1235,22 +1546,43 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if 
(n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture if (hparams.f_embedding_scale != 0.0f) { cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); + // make sure the produced embeddings are immediately materialized in the ggml graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18599 + ggml_build_forward_expand(gf, cur); + return cur; } @@ -1342,7 +1674,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1420,7 +1752,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1616,14 +1949,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1649,9 +1979,11 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1679,6 +2011,89 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +static std::unique_ptr build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + 
ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + if (wo) { cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { @@ -1824,32 +2239,30 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); + ggml_set_name(inp->self_kq_mask, "self_kq_mask"); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); + ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); @@ -1985,17 +2398,71 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + { + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? 
res->t_embd_pooled : res->t_embd; GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); - cur = ggml_mul_mat(ctx0, dense_2, cur); - cur = ggml_mul_mat(ctx0, dense_3, cur); + if (dense_2) { + cur = ggml_mul_mat(ctx0, dense_2, cur); + } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } + if (dense_3) { + cur = ggml_mul_mat(ctx0, dense_3, cur); + } cb(cur, "result_embd_pooled", -1); res->t_embd_pooled = cur; ggml_build_forward_expand(gf, cur); @@ -2006,7 +2473,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2045,8 +2513,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2055,7 +2530,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: 
https://huggingface.co/jinaai/jina-reranker-v1-tiny-en @@ -2086,6 +2569,94 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::build_sampling() const { + if (samplers.empty() || !res->t_logits) { + return; + } + + std::array outs; + outs[0] = res->t_logits; + + auto inp_sampling = std::make_unique(samplers); + res->add_input(std::move(inp_sampling)); + + std::map seq_to_logit_row; + int32_t logit_row_idx = 0; + + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (ubatch.output[i]) { + llama_seq_id seq_id = ubatch.seq_id[i][0]; + seq_to_logit_row[seq_id] = logit_row_idx; + logit_row_idx++; + } + } + + // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1) + GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); + + // add a dummy row of logits + // this trick makes the graph static, regardless of which samplers are activated + // this is important in order to minimize graph reallocations + ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); + + for (const auto & [seq_id, sampler] : samplers) { + const auto it = seq_to_logit_row.find(seq_id); + + // inactive samplers always work on the first row + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 
1 : 0; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_data data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; + + assert(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); + + if (data.sampled != nullptr) { + res->t_sampled[seq_id] = data.sampled; + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } + + if (data.probs != nullptr) { + res->t_sampled_probs[seq_id] = data.probs; + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } + + if (data.logits != nullptr) { + res->t_sampled_logits[seq_id] = data.logits; + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } + + if (data.candidates != nullptr) { + res->t_candidates[seq_id] = data.candidates; + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. 
+ /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index 81ac329cc31..e8f006977d2 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -10,6 +10,7 @@ #include #include #include +#include struct ggml_cgraph; struct ggml_context; @@ -23,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -104,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -113,6 +115,8 @@ class llm_graph_input_embd : public llm_graph_input_i { ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -313,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + 
llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -396,6 +433,74 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + 
std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(std::map samplers) : + samplers(std::move(samplers)) { } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + std::map samplers; +}; + // // llm_graph_result // @@ -429,6 +534,23 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + + static bool samplers_equal( + const std::map & lhs, + const std::map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -468,15 +590,36 @@ struct llm_graph_params { return false; } + if (n_outputs != other.n_outputs) { + return false; + } + + if (!samplers_equal(samplers, other.samplers)) { + return false; + } + + if (samplers.size() > 0) { + if (!ubatch.data || !other.ubatch.data) { + return false; + } + + // check that the outputs are the same for all samplers + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.output[i] != other.ubatch.output[i] || + ubatch.seq_id[i][0] != 
other.ubatch.seq_id[i][0]) { + return false; + } + } + } + return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross && - n_outputs == other.n_outputs; + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross; } }; @@ -486,7 +629,7 @@ class llm_graph_result { virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } @@ -499,6 +642,7 @@ class llm_graph_result { void reset(); void set_inputs(const llama_ubatch * ubatch); + void set_outputs(); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -512,11 +656,17 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::map t_sampled_logits; + std::map t_candidates; + std::map t_sampled; + std::map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -592,6 +742,8 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -662,7 +814,8 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = 
nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -683,7 +836,9 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * gate_up_exps_b = nullptr) const; // // inputs @@ -742,6 +897,21 @@ struct llm_graph_context { ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -821,6 +991,9 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; + + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; // // pooling @@ -830,7 +1003,14 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; + + // + // sampling (backend sampling) + // + + void build_sampling() const; // // dense (out) @@ -838,6 +1018,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/llama/llama.cpp/src/llama-hparams.cpp 
b/llama/llama.cpp/src/llama-hparams.cpp index aabff2f066b..515a900b35f 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; +} + uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); @@ -135,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -147,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -179,6 +197,21 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? 
n_embd_head_k_mla_impl : n_embd_head_k; +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v; +} + bool llama_hparams::has_kv(uint32_t il) const { if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { @@ -204,42 +237,6 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; -} - bool llama_hparams::use_mrope() const { return rope_sections[0] > 0 && rope_sections[1] > 0; } diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index c6e67327620..ccca1bd5f09 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -41,7 +42,6 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; @@ -52,8 +52,8 @@ struct llama_hparams { uint32_t n_rel_attn_bkts = 0; // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t 
n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -107,9 +107,9 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; - float rope_freq_base_train_swa; + float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; - float rope_freq_scale_train_swa; + float rope_freq_scale_train_swa = 1.0f; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -125,10 +125,11 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == true, then layer il is SWA - // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // if swa_layers[il] == 1, then layer il is SWA + // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense - std::array swa_layers; + // note: using uint32_t type for compatibility reason + std::array swa_layers; // for State Space Models uint32_t ssm_d_conv = 0; @@ -137,6 +138,9 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; + // for hybrid state space models std::array recurrent_layer_arr; @@ -163,6 +167,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // output embedding dimension (0 = use n_embd) + uint32_t n_embd_out_impl = 0; + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; @@ -188,11 +195,16 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -200,6 +212,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA @@ -235,6 +252,9 @@ struct llama_hparams { // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; + // dimension of output embeddings + uint32_t n_embd_out() const; + // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -266,15 +286,57 @@ struct llama_hparams { bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; + bool has_kv(uint32_t il) const; // number of layers for which has_kv() returns true uint32_t n_layer_kv() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? 
- static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/llama/llama.cpp/src/llama-impl.cpp b/llama/llama.cpp/src/llama-impl.cpp index 8e3e7b223a6..710a5a1e08d 100644 --- a/llama/llama.cpp/src/llama-impl.cpp +++ b/llama/llama.cpp/src/llama-impl.cpp @@ -109,9 +109,9 @@ std::string llama_format_tensor_shape(const std::vector & ne) { std::string llama_format_tensor_shape(const struct ggml_tensor * t) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + snprintf(buf, sizeof(buf), "%6" PRId64, t->ne[0]); for (int i = 1; i < GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, t->ne[i]); } return buf; } diff --git a/llama/llama.cpp/src/llama-impl.h b/llama/llama.cpp/src/llama-impl.h index c3391e79f51..dfd9fee9f44 100644 --- a/llama/llama.cpp/src/llama-impl.h +++ b/llama/llama.cpp/src/llama-impl.h @@ -49,6 +49,16 @@ struct time_meas { int64_t & t_acc; }; +template +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string 
& search, const std::string & replace); // TODO: rename to llama_format ? diff --git a/llama/llama.cpp/src/llama-kv-cache-iswa.cpp b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp index 3a34102a23d..26e2cb4270b 100644 --- a/llama/llama.cpp/src/llama-kv-cache-iswa.cpp +++ b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index 3186242d60f..6b668ee9abd 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } + const bool is_mla = hparams.is_mla(); + for (uint32_t il = 0; il < hparams.n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; + + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? 
ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector k_stream; std::vector v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co const auto & layer = layers[il]; ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); - ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + + if (layer.v_stream[ssrc]) { + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } } } } @@ -852,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -966,6 +974,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. 
+ if (model.arch == LLM_ARCH_STEP35) { + return false; + } + if (hparams.n_pos_per_embd() > 1) { + return false; + } return true; } @@ -1237,90 +1252,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } -void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const uint32_t n_tokens = ubatch->n_tokens; +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; + const std::vector & v_cells; + const std::vector & seq_to_stream; - const int64_t n_kv = dst->ne[0]; - const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + uint32_t n_swa; + llama_swa_type swa_type; - GGML_ASSERT(n_tokens%n_stream == 0); + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; - // n_tps == n_tokens_per_stream - const int64_t n_tps = n_tokens/n_stream; +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; - std::fill(data, data + ggml_nelements(dst), -INFINITY); - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; - const auto & cells = v_cells[seq_to_stream[seq_id]]; + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; - const llama_pos p1 = ubatch->pos[i]; + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? 
ubatch->pos[i + ubatch->n_tokens] : 0; + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } + for (uint32_t s = 0; s < n_stream; ++s) { + // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map<llama_seq_id, uint32_t> seq_srct; + std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs; + + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); + + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i] - // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.)
+ keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; + + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an existing mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } - const llama_pos p0 = cells.pos_get(j); + p0 = cells.pos_get(j); + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + + if (causal) { // mask future tokens - if (causal_attn && p0 > p1) { - continue + if (p0 > p1) { + goto skip; } // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } } } + } - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; } + } - data[idst + j] = hparams.use_alibi ?
-std::abs(p0 - p1) : 0.0f; + if (alibi) { + data[idst + j] = -std::abs(p0 - p1); + } else { + data[idst + j] = 0.0f; } + + continue; +skip: + data[idst + j] = -INFINITY; } } } } +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + + //const int64_t t_start = ggml_time_us(); + + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; + + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, 
(t_end - t_start)/1000.0); +} + void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; @@ -1370,7 +1531,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1448,6 +1609,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1468,10 +1633,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -1483,10 +1648,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); @@ -1618,8 +1779,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, 
each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1637,7 +1796,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1652,6 +1811,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1661,7 +1823,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1678,6 +1840,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1692,7 +1857,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for 
(const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; @@ -1881,6 +2046,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; @@ -1922,6 +2090,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index 1868f118572..e194bf3e26f 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -257,8 +257,6 @@ class llama_kv_cache : public llama_memory_i { size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, @@ -305,7 +303,7 @@ class llama_kv_cache_context : public llama_memory_context_i { bool do_shift, stream_copy_info sc_info); - // used to create a batch procesing context from a batch + // used to create a batch processing context from a batch llama_kv_cache_context( llama_kv_cache * kv, slot_info_vec_t sinfos, diff --git a/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp b/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp new file mode 100644 index 00000000000..411769672af --- /dev/null +++ b/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type 
type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? 
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. 
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + 
+llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. 
not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/llama/llama.cpp/src/llama-memory-hybrid-iswa.h b/llama/llama.cpp/src/llama-memory-hybrid-iswa.h new file mode 100644 index 00000000000..807c8aac96c --- /dev/null +++ b/llama/llama.cpp/src/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { 
+public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class 
llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/llama/llama.cpp/src/llama-memory-recurrent.cpp b/llama/llama.cpp/src/llama-memory-recurrent.cpp index 812bf253049..6e8413f493d 100644 --- a/llama/llama.cpp/src/llama-memory-recurrent.cpp +++ b/llama/llama.cpp/src/llama-memory-recurrent.cpp @@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos 
= %d, p1 = %d\n", p0, cell.pos, p1); return false; } // invalidate tails which will be cleared @@ -785,23 +785,21 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each range of cells of r_size_row length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +812,15 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each range of S tensor rows for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * 
s_size_row; @@ -830,7 +828,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +836,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -851,7 +849,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index 0641c2d22f6..c03228e9ce2 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -13,9 +13,10 @@ #ifdef __has_include #if __has_include() #include + #include + #include #if defined(_POSIX_MAPPED_FILES) #include - #include #endif #if defined(_POSIX_MEMLOCK_RANGE) #include @@ -74,7 +75,7 @@ struct llama_file::impl { return ret; } - impl(const char * fname, const char * mode) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { fp = ggml_fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, 
strerror(errno))); @@ -109,7 +110,7 @@ struct llama_file::impl { } } - void read_raw(void * ptr, size_t len) const { + void read_raw(void * ptr, size_t len) { size_t bytes_read = 0; while (bytes_read < len) { size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); @@ -126,7 +127,7 @@ struct llama_file::impl { } } - uint32_t read_u32() const { + uint32_t read_u32() { uint32_t val; read_raw(&val, sizeof(val)); return val; @@ -153,16 +154,55 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return true; + } + ~impl() { if (fp) { std::fclose(fp); } } #else - impl(const char * fname, const char * mode) { - fp = ggml_fopen(fname, mode); + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) { +#ifdef __linux__ + // Try unbuffered I/O for read only + if (use_direct_io && std::strcmp(mode, "rb") == 0) { + if (init_fd()) { + return; + } + LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O", + fname, strerror(errno)); + } +#endif + init_fp(mode); + } + +#ifdef __linux__ + bool init_fd() { + fd = open(fname.c_str(), O_RDONLY | O_DIRECT); + + if (fd != -1) { + struct stat file_stats{}; + fstat(fd, &file_stats); + + size = file_stats.st_size; + alignment = file_stats.st_blksize; + + off_t ret = lseek(fd, 0, SEEK_SET); + if (ret == -1) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + return true; + } + return false; + } +#endif + + void init_fp(const char * mode) { + fp = ggml_fopen(fname.c_str(), mode); if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno))); } seek(0, SEEK_END); size = tell(); @@ -170,46 +210,122 @@ struct llama_file::impl { } size_t tell() const { -// TODO: this ifdef is never true? 
-#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - if (ret == -1) { - throw std::runtime_error(format("ftell error: %s", strerror(errno))); + if (fd == -1) { + long ret = std::ftell(fp); + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; } - return (size_t) ret; + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos == -1) { + throw std::runtime_error(format("lseek error: %s", strerror(errno))); + } + return (size_t) pos; } void seek(size_t offset, int whence) const { -// TODO: this ifdef is never true? -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - if (ret != 0) { + off_t ret = 0; + if (fd == -1) { + ret = std::fseek(fp, (long) offset, whence); + } else { + ret = lseek(fd, offset, whence); + } + if (ret == -1) { throw std::runtime_error(format("seek error: %s", strerror(errno))); } } - void read_raw(void * ptr, size_t len) const { + void read_raw_unsafe(void * ptr, size_t len) { if (len == 0) { return; } errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); + if (fd == -1) { + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (to_read > 0 && ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } else { + size_t bytes_read = 0; + while (bytes_read < len) { + const size_t to_read = len - bytes_read; + ssize_t ret = ::read(fd, reinterpret_cast(ptr) + bytes_read, to_read); + + if (ret == -1) { + if (errno == EINTR) { + continue; // Interrupted by signal, retry + } + // Fallback to std::fread in case the DMA controller cannot access the buffer + if (errno == EFAULT 
|| errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); + auto curr_off = tell(); + close(fd); + fd = -1; + alignment = 1; + init_fp("rb"); + seek(curr_off, SEEK_SET); + read_raw_unsafe(ptr, len); + return; + } + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret == 0) { + // EOF: allow if this read was only pulling alignment padding past file end + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos != -1 && (size_t) pos == size) { + std::memset(reinterpret_cast(ptr) + bytes_read, 0, len - bytes_read); + return; + } + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += (size_t) ret; + } } - if (ret != 1) { - throw std::runtime_error("unexpectedly reached end of file"); + } + + void read_aligned_chunk(void * dest, size_t size) { + size_t offset = tell(); + off_t aligned_offset = offset & ~(alignment - 1); + off_t offset_from_alignment = offset - aligned_offset; + size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1); + + void * raw_buffer = nullptr; + int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read); + if (ret != 0) { + throw std::runtime_error(format("posix_memalign failed with error %d", ret)); } + + struct aligned_buffer_deleter { + void operator()(void * p) const { free(p); } + }; + std::unique_ptr buffer(raw_buffer); + + seek(aligned_offset, SEEK_SET); + read_raw_unsafe(buffer.get(), bytes_to_read); + + uintptr_t actual_data = reinterpret_cast(buffer.get()) + offset_from_alignment; + memcpy(dest, reinterpret_cast(actual_data), size); } - uint32_t read_u32() const { + void read_raw(void * ptr, size_t len) { + if (has_direct_io()) { + read_aligned_chunk(ptr, len); + } else { + read_raw_unsafe(ptr, len); + } + } + + uint32_t read_u32() { uint32_t ret; read_raw(&ret, sizeof(ret)); return ret; @@ -230,27 +346,48 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { 
+ return fd != -1 && alignment > 1; + } + ~impl() { - if (fp) { + if (fd != -1) { + close(fd); + } else { std::fclose(fp); } } + int fd = -1; + std::string fname; #endif - FILE * fp; - size_t size; + size_t read_alignment() const { + return alignment; + } + + size_t alignment = 1; + + FILE * fp{}; + size_t size{}; }; -llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : + pimpl(std::make_unique(fname, mode, use_direct_io)) {} llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } +size_t llama_file::read_alignment() const { return pimpl->read_alignment(); } +bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); } + int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -260,9 +397,14 @@ int llama_file::file_id() const { } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } -void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } +void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#ifdef _WIN32 +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#else +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); } +#endif -uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } +uint32_t llama_file::read_u32() { return pimpl->read_u32(); } void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } @@ -362,6 +504,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + 
impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -369,7 +513,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -378,9 +522,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -412,9 +556,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else @@ -476,9 +628,9 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); -#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) - // visionOS/tvOS dont't support RLIMIT_MEMLOCK - // Skip resource limit checks on visionOS/tvOS +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__) + // visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK + // Skip resource limit checks on these platforms suggest = false; #else struct rlimit lock_limit; diff --git a/llama/llama.cpp/src/llama-mmap.h b/llama/llama.cpp/src/llama-mmap.h index 4e5aec3f440..29ce4d24685 100644 --- a/llama/llama.cpp/src/llama-mmap.h +++ b/llama/llama.cpp/src/llama-mmap.h @@ -3,6 +3,7 @@ #include #include #include 
+#include struct llama_file; struct llama_mmap; @@ -13,7 +14,7 @@ using llama_mmaps = std::vector>; using llama_mlocks = std::vector>; struct llama_file { - llama_file(const char * fname, const char * mode); + llama_file(const char * fname, const char * mode, bool use_direct_io = false); ~llama_file(); size_t tell() const; @@ -23,12 +24,16 @@ struct llama_file { void seek(size_t offset, int whence) const; - void read_raw(void * ptr, size_t len) const; - uint32_t read_u32() const; + void read_raw(void * ptr, size_t len); + void read_raw_unsafe(void * ptr, size_t len); + void read_aligned_chunk(void * dest, size_t size); + uint32_t read_u32(); void write_raw(const void * ptr, size_t len) const; void write_u32(uint32_t val) const; + size_t read_alignment() const; + bool has_direct_io() const; private: struct impl; std::unique_ptr pimpl; diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index 8916a6242f4..37b69a4b366 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -2,6 +2,7 @@ #include "ggml.h" +#include #include #include #include @@ -344,6 +345,7 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { + case GGUF_TYPE_BOOL: case GGUF_TYPE_UINT32: case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || (std::is_same::value)); break; @@ -365,7 +367,13 @@ namespace GGUFMeta { result[i] = value; } } else { - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if (arr_info.gt == GGUF_TYPE_BOOL) { + std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { + return static_cast(x); + }); + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } } return true; @@ -462,6 +470,29 @@ namespace GGUFMeta { return get_key_or_arr(llm_kv(kid), result, n, required); } + bool 
llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { + const std::string key = llm_kv(kid); + + const int id = gguf_find_key(meta.get(), key.c_str()); + + if (id < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + // throw and error if type is an array + if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); + } + return false; + } + + return get_key(key, result, required); + } + // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); @@ -472,6 +503,7 @@ llama_model_loader::llama_model_loader( const std::string & fname, std::vector & splits, bool use_mmap, + bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, @@ -504,9 +536,23 @@ llama_model_loader::llama_model_loader( get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - files.emplace_back(new llama_file(fname.c_str(), "rb")); + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; + + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } + } + // Save tensors data offset of the main file. 
// For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. @@ -572,7 +618,7 @@ llama_model_loader::llama_model_loader( } } - files.emplace_back(new llama_file(fname_split, "rb")); + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -716,6 +762,7 @@ llama_model_loader::llama_model_loader( } this->use_mmap = use_mmap; + this->use_direct_io = use_direct_io; this->check_tensors = check_tensors; this->no_alloc = no_alloc; } @@ -935,7 +982,15 @@ bool llama_model_loader::load_all_data( // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. // NVMe raid configurations might require more / larger buffers. constexpr size_t n_buffers = 4; - constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + size_t alignment = 1; + for (const auto & file : files) { + alignment = std::max(file->read_alignment(), alignment); + } + + // Buffer size: balance between memory usage and I/O efficiency + // 64MB works well for NVMe drives + const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024; std::vector host_buffers; std::vector events; @@ -985,6 +1040,7 @@ bool llama_model_loader::load_all_data( // If the backend is supported, create pinned memory buffers and events for synchronisation. 
for (size_t idx = 0; idx < n_buffers; ++idx) { auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, ggml_backend_dev_name(dev)); @@ -1066,6 +1122,7 @@ bool llama_model_loader::load_all_data( } } else { const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); file->read_raw(cur->data, n_size); @@ -1077,19 +1134,54 @@ bool llama_model_loader::load_all_data( } else { // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. if (upload_backend) { - file->seek(weight->offs, SEEK_SET); + size_t offset = weight->offs; + alignment = file->read_alignment(); + size_t aligned_offset = offset & ~(alignment - 1); + size_t offset_from_alignment = offset - aligned_offset; + file->seek(aligned_offset, SEEK_SET); + + // Calculate aligned read boundaries + size_t read_start = aligned_offset; + size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1); size_t bytes_read = 0; + size_t data_read = 0; // Actual tensor data copied (excluding padding) + + while (bytes_read < read_end - read_start) { + size_t read_size = std::min(buffer_size, read_end - read_start - bytes_read); - while (bytes_read < n_size) { - size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + // Align the destination pointer within the pinned buffer + uintptr_t ptr_dest_aligned = (reinterpret_cast(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1); + // Wait for previous upload to complete before reusing buffer ggml_backend_event_synchronize(events[buffer_idx]); - file->read_raw(host_ptrs[buffer_idx], read_iteration); - ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + + // Read aligned chunk from file + file->read_raw_unsafe(reinterpret_cast(ptr_dest_aligned), 
read_size); + + // Calculate actual data portion (excluding alignment padding) + uintptr_t ptr_data = ptr_dest_aligned; + size_t data_to_copy = read_size; + + // Skip alignment padding at start of first chunk + if (bytes_read == 0) { + ptr_data += offset_from_alignment; + data_to_copy -= offset_from_alignment; + } + + // Trim alignment padding at end of last chunk + if (aligned_offset + bytes_read + read_size > offset + n_size) { + data_to_copy -= (read_end - (offset + n_size)); + } + + // Async upload actual data to GPU + ggml_backend_tensor_set_async(upload_backend, cur, + reinterpret_cast(ptr_data), data_read, data_to_copy); ggml_backend_event_record(events[buffer_idx], upload_backend); - bytes_read += read_iteration; + data_read += data_to_copy; + bytes_read += read_size; + ++buffer_idx; buffer_idx %= n_buffers; } diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index 0380c92fde0..65953dd3d5a 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -70,6 +70,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool use_direct_io = false; bool check_tensors; bool no_alloc; @@ -97,6 +98,7 @@ struct llama_model_loader { const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, + bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, @@ -131,6 +133,8 @@ struct llama_model_loader { template bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true); + bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true); + std::string get_arch_name() const; enum llm_arch get_arch() const; diff --git a/llama/llama.cpp/src/llama-model-saver.cpp b/llama/llama.cpp/src/llama-model-saver.cpp index 563823dc35d..676efeda709 100644 --- a/llama/llama.cpp/src/llama-model-saver.cpp +++ 
b/llama/llama.cpp/src/llama-model-saver.cpp @@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + if (hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); + } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); @@ -268,6 +271,7 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model.cls_b); add_tensor(model.cls_out); add_tensor(model.cls_out_b); + add_tensor(model.cls_norm); for (const struct llama_layer & layer : model.layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 00cd579e02e..6436cde36c3 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -31,12 +32,14 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17M: return "17M"; case LLM_TYPE_22M: return "22M"; case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_47M: return "47M"; case LLM_TYPE_60M: return "60M"; case LLM_TYPE_70M: return "70M"; case LLM_TYPE_80M: return "80M"; case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; case LLM_TYPE_140M: return "140M"; + case LLM_TYPE_149M: return "149M"; case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; @@ -46,6 +49,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_335M: return "335M"; case LLM_TYPE_350M: return "350M"; case LLM_TYPE_360M: return "360M"; + case LLM_TYPE_395M: 
return "395M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; @@ -119,15 +123,22 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_24B_A2B: return "24B.A2B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -441,7 +452,7 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + // contexts where the model tensors metadata is stored as well as the corresponding buffers: std::vector>> ctxs_bufs; buft_list_t cpu_buft_list; @@ -463,7 +474,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() = default; +llama_model::~llama_model() { + for (auto * lora : loras) { + delete lora; + } +} void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -502,6 +517,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, 
hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -509,7 +525,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -548,6 +565,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -573,6 +592,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + // TODO: Handle SWA metadata similarly when models start implementing it // rope_freq_scale (inverse of the kv) is optional float ropescale = 0.0f; if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { @@ -581,10 +601,6 @@ void 
llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; - // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -603,7 +619,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -627,6 +643,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_EMBED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -671,6 +688,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } switch (hparams.n_expert) { @@ -716,6 +737,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 
hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -875,6 +900,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 3; + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -926,6 +979,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { type = LLM_TYPE_250M; } } break; + case LLM_ARCH_EUROBERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1076,6 +1139,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case 
LLM_ARCH_MAINCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VL: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); @@ -1194,6 +1265,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; + case LLM_ARCH_PLAMO3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 8; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1247,7 +1337,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 4096; // default value of gemma 2 hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); @@ -1272,8 +1365,7 @@ void 
llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1303,10 +1395,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(5); hparams.n_layer_kv_from_start = 20; - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; hparams.f_attention_scale = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1322,9 +1413,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); @@ -1463,7 +1553,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1502,6 +1595,10 @@ void llama_model::load_hparams(llama_model_loader & 
ml) { if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1617,25 +1714,31 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing 
DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -1652,6 +1755,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1691,7 +1795,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1725,10 +1837,55 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case 
LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1791,6 +1948,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAIS2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1843,6 +2010,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -1854,6 +2025,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE_MOE: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, 
hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -2071,7 +2270,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (arch == LLM_ARCH_ERNIE4_5_MOE) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -2160,6 +2363,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + switch (hparams.n_layer) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; @@ -2181,6 +2388,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } } break; case LLM_ARCH_LFM2MOE: { @@ -2194,7 +2407,11 @@ void 
llama_model::load_hparams(llama_model_loader & ml) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } - type = LLM_TYPE_8B_A1B; + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -2204,6 +2421,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; @@ -2287,11 +2508,71 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + 
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } } switch (hparams.n_layer) { + case 28: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_80B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -2322,6 +2603,82 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MIMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + 
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STEP35: + { + 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2344,15 +2701,16 @@ void llama_model::load_vocab(llama_model_loader & ml) { bool llama_model::load_tensors(llama_model_loader & ml) { const auto & split_mode = params.split_mode; - const auto & n_gpu_layers = params.n_gpu_layers; const auto & use_mlock = params.use_mlock; const auto & tensor_split = params.tensor_split; - const int n_layer = hparams.n_layer; + const int n_layer = hparams.n_layer; + const int n_gpu_layers = this->n_gpu_layers(); const bool use_mmap_buffer = true; - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? 
"true" : "false"); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); @@ -2363,6 +2721,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + // calculate the split points bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); std::vector splits(n_devices()); @@ -2373,6 +2736,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); + + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); + } splits[i] = free; } } else { @@ -2389,14 +2759,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 
0 : std::min(n_gpu_layers, int(n_layer) + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il); + const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -2629,6 +2995,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // TODO: move to a separate function const auto tn = LLM_TN(arch); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } + }; switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -2636,6 +3011,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3170,6 +3546,38 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), 
{n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); + + } break; case LLM_ARCH_NEO_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3196,7 +3604,30 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; - case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_EUROBERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = 
create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings @@ -3234,7 +3665,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? 
t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -3762,6 +4200,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); } } break; + case LLM_ARCH_PLAMO3: + { + const int64_t head_dim_q = hparams.n_embd_head_k; + const int64_t head_dim_v = hparams.n_embd_head_v; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4652,7 +5128,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4693,14 +5173,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? 
hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; @@ -4715,19 +5192,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); } layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { @@ -4764,9 +5245,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, 
n_ff_exp, n_expert, 0); // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); @@ -4970,6 +5450,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); } } break; + case LLM_ARCH_JAIS2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = 
create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } + } break; case LLM_ARCH_CHATGLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5020,30 +5539,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = 
create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for 
last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_GLM4_MOE: @@ -5082,9 +5619,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * 
n_head, n_embd }, flags); @@ -5147,6 +5684,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 
flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + 
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { 
n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5196,9 +5835,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5250,6 +5886,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); @@ -5334,6 +5973,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_EXAONE_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + const int64_t head_dim = hparams.n_embd_head_k; + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output 
= create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 
TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5652,9 +6369,9 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); // posnet @@ -5750,8 +6467,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); } break; case LLM_ARCH_BAILINGMOE: { @@ -6007,6 +6724,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6279,8 +6997,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); if (output == NULL) { output = 
create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -6325,6 +7043,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); } } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -6480,15 +7202,150 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; - case LLM_ARCH_COGVLM: + case LLM_ARCH_KIMI_LINEAR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - // if output is NULL, init from the input tok embed + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + 
layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = 
create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim 
= hparams.n_rot; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = 
create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } + } break; + case LLM_ARCH_COGVLM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } @@ -6606,7 +7463,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = 
create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); @@ -6616,9 +7476,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); @@ -6627,6 +7486,265 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_QWEN35MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = 
create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = 
hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, 
"weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_MIMO2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, 
"weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. 
+ uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = 
create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MAINCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6736,10 +7854,12 @@ bool llama_model::load_tensors(llama_model_loader 
& ml) { if (llama_supports_gpu_offload()) { const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int) hparams.n_layer) { + int n_repeating = n_gpu; + if (n_repeating > 0) { LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; @@ -6806,6 +7926,14 @@ size_t llama_model::n_devices() const { return devices.size(); } +uint32_t llama_model::n_gpu_layers() const { + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; +} + +llama_split_mode llama_model::split_mode() const { + return params.split_mode; +} + std::map llama_model::memory_breakdown() const { std::map ret; for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { @@ -6862,55 +7990,59 @@ void llama_model::print_info() const { }; // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, 
print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, 
hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = 
%.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + } + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? 
"yes" : "unknown"); // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); } if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; for (auto label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } } @@ -6922,57 +8054,59 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, 
type_name().c_str()); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); } // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, 
hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_QWEN3MOE || 
arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } if (arch == LLM_ARCH_MINICPM || @@ -6980,41 +8114,41 @@ void llama_model::print_info() const { arch == LLM_ARCH_GRANITE_MOE || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } if (arch == 
LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) 
hparams.expert_gating_func)); } if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); } vocab.print_info(); @@ -7129,7 +8263,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -7152,7 +8288,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -7169,23 +8304,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* 
offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; @@ -7247,16 +8403,24 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { switch (arch) { case LLM_ARCH_LLAMA: { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } break; case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique(*this, params); + llm = 
std::make_unique>(*this, params); } else { llm = std::make_unique(*this, params); } } break; + case LLM_ARCH_LLAMA_EMBED: + { + llm = std::make_unique>(*this, params); + } break; + case LLM_ARCH_MAINCODER: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -7289,10 +8453,18 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EUROBERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BLOOM: { llm = std::make_unique(*this, params); @@ -7378,6 +8550,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PLAMO3: + { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params); + } else { + llm = std::make_unique>(*this, params); + } + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params); @@ -7484,6 +8664,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: { llm = std::make_unique(*this, params); } break; @@ -7526,6 +8707,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_JAIS2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -7547,6 +8732,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_EXAONE_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_RWKV6: { llm = 
std::make_unique(*this, params); @@ -7621,6 +8810,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PADDLEOCR: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_HUNYUAN_MOE: { llm = std::make_unique(*this, params); @@ -7644,7 +8837,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -7678,22 +8875,47 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN35: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN35MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MIMO2: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); + + // add backend sampling layers (if any) + llm->build_sampling(); // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); + + 
llm->res->set_outputs(); return llm->res->get_gf(); } @@ -7707,7 +8929,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, - /*.n_gpu_layers =*/ 999, + /*.n_gpu_layers =*/ -1, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -7716,6 +8938,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -7750,6 +8973,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { return model->hparams.n_embd_inp(); } +int32_t llama_model_n_embd_out(const llama_model * model) { + return model->hparams.n_embd_out(); +} + int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer; } @@ -7821,6 +9048,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -7853,6 +9081,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: + case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -7862,8 +9093,10 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_EUROBERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: @@ -7881,6 +9114,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { 
case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO2: + case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -7894,10 +9128,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: + case LLM_ARCH_EXAONE_MOE: case LLM_ARCH_MINICPM3: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: @@ -7911,12 +9147,17 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_PADDLEOCR: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index b378b23ec84..679977bee69 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -11,6 +11,7 @@ #include #include #include +#include #include struct llama_cparams; @@ -24,12 +25,14 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, LLM_TYPE_140M, + LLM_TYPE_149M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -39,6 +42,7 @@ enum llm_type { LLM_TYPE_335M, LLM_TYPE_350M, LLM_TYPE_360M, + LLM_TYPE_395M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, @@ -113,15 +117,22 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_24B_A2B, // lfm2moe LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, + 
LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -270,14 +281,16 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; - struct ggml_tensor * ffn_gate_inp_b = nullptr; - struct ggml_tensor * ffn_gate_exps_b = nullptr; - struct ggml_tensor * ffn_down_exps_b = nullptr; - struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_up_exps_b = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -315,6 +328,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -406,6 +422,25 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + 
struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct ggml_tensor * bskcn_tv = nullptr; struct llama_layer_posnet posnet; @@ -446,6 +481,7 @@ struct llama_model { struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; @@ -462,10 +498,9 @@ struct llama_model { //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; - - llama_model_params params; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map gguf_kv; @@ -476,6 +511,9 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; + // for keeping track of associated LoRA adapters + std::unordered_set loras; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -497,6 +535,9 @@ struct llama_model { size_t n_tensors() const; size_t n_devices() const; + uint32_t n_gpu_layers() const; + llama_split_mode split_mode() const; + 
std::map memory_breakdown() const; // total number of parameters in the model @@ -525,6 +566,8 @@ struct llama_model { ggml_cgraph * build_graph(const llm_graph_params & params) const; private: + llama_model_params params; + struct impl; std::unique_ptr pimpl; }; diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index bc4b05c3b50..24770430e1c 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case 
GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - return new_type; } @@ -530,6 +479,17 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } +static bool tensor_type_requires_imatrix(const ggml_tensor * t, const ggml_type dst_type, const llama_ftype ftype) { + return ( + dst_type == GGML_TYPE_IQ2_XXS || dst_type == GGML_TYPE_IQ2_XS || + dst_type == GGML_TYPE_IQ3_XXS || dst_type == GGML_TYPE_IQ1_S || + dst_type == GGML_TYPE_IQ2_S || dst_type == GGML_TYPE_IQ1_M || + ( // Q2_K_S is the worst k-quant type - only allow it without imatrix for token embeddings + dst_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(t->name, "token_embd.weight") != 0 + ) + ); +} + static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; @@ -596,7 +556,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -786,24 +746,36 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }; const auto tn 
= LLM_TN(model.arch); - new_ofstream(0); + + // no output file for --dry-run + if (!params->dry_run) { + new_ofstream(0); + } + + // flag for `--dry-run`, to let the user know if imatrix will be required for a real + // quantization, as a courtesy + bool will_require_imatrix = false; + for (const auto * it : tensors) { const auto & weight = *it; ggml_tensor * tensor = weight.tensor; - if (weight.idx != cur_split && params->keep_split) { + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); } const std::string name = ggml_get_name(tensor); + const size_t tensor_size = ggml_nbytes(tensor); - if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!params->dry_run) { + if (!ml.use_mmap) { + if (read_data.size() < tensor_size) { + read_data.resize(tensor_size); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); + ml.load_data_for(tensor); } - ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", ++idx, ml.n_tensors, @@ -838,9 +810,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize Mamba /Kimi's small conv1d weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("ssm_conv1d") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights @@ -875,21 +847,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. 
if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { + // if the user provided tensor types - use those + bool manual = false; + if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + manual = true; + break; } } } } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + { + bool convert_incompatible_tensor = false; + + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case 
GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; @@ -903,129 +923,155 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize = tensor->type != new_type; } - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); + // we have now decided on the target type for this tensor + if (params->dry_run) { + // the --dry-run option calculates the final quantization size without quantizting + if (quantize) { + new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", + tensor_size/1024.0/1024.0, + new_size/1024.0/1024.0, + ggml_type_name(new_type)); + if (!will_require_imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + will_require_imatrix = true; + } + } else { + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0); + } + total_size_org += tensor_size; + total_size_new += new_size; + continue; } else { - const int64_t nelements = ggml_nelements(tensor); + // no --dry-run, perform quantization + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = 
tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); - const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); - } else { - if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { - imatrix = it->second.data(); + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { - LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - - // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix - // this is a significant error and it may be good idea to abort the process if this happens, - // since many people will miss the error and not realize that most of the model is being quantized without an imatrix - // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it 
may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } } } } - } - if ((new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ2_XS || - new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ1_S || - (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { - LLAMA_LOG_ERROR("\n\n============================================================\n"); - LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); - LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); - LLAMA_LOG_ERROR("============================================================\n\n"); - throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); - } + if (!imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); + } - float * f32_data; + float * f32_data; - if (tensor->type 
== GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); - } else { - llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); - } + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } - LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); - fflush(stdout); + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); - if (work.size() < (size_t)nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); - const int64_t n_per_row = tensor->ne[0]; - const int64_t nrows = tensor->ne[1]; + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; - static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int64_t nthread_use = nthread > 1 ? 
std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; - // quantize each expert separately since they have different importance matrices - new_size = 0; - for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { - const float * f32_data_03 = f32_data + i03 * nelements_matrix; - void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; - const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; - new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless + // TODO: temporary sanity check that the F16 -> MXFP4 is lossless #if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + if (new_type == GGML_TYPE_MXFP4) { + auto * x = f32_data_03; + + //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); + std::vector deq(nrows*n_per_row); + const ggml_type_traits * qtype = ggml_get_type_traits(new_type); + qtype->to_float(new_data_03, deq.data(), deq.size()); + + double err = 0.0f; + for (int i = 0; i < (int) deq.size(); ++i) { + err += fabsf(deq[i] - x[i]); + //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { + if (deq[i] != x[i]) { + LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + } } + //LLAMA_LOG_INFO("err = %f\n", err); + GGML_ASSERT(err == 0.00000); } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); - } #endif + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); } - LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; 
+ total_size_org += tensor_size; + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } // no --dry-run + } // iterate over tensors + + if (!params->dry_run) { + close_ofstream(); + } - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); + if (!params->imatrix && params->dry_run && will_require_imatrix) { + LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n", + __func__ + ); } - close_ofstream(); - - LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", @@ -1048,6 +1094,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ 
false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.dry_run =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampler.cpp similarity index 61% rename from llama/llama.cpp/src/llama-sampling.cpp rename to llama/llama.cpp/src/llama-sampler.cpp index 38a30ea05e2..5cf66b63f1f 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampler.cpp @@ -1,9 +1,11 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" #include "llama-grammar.h" +#include "ggml-cpp.h" + #include #include #include @@ -346,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API -struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { +struct llama_sampler * llama_sampler_init( + struct llama_sampler_i * iface, + llama_sampler_context_t ctx) { return new llama_sampler { /* .iface = */ iface, /* .ctx = */ ctx, @@ -362,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) { } void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (!smpl) { + return; + } + if (smpl->iface->accept) { smpl->iface->accept(smpl, token); } } void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + if (!smpl) { + return; + } + GGML_ASSERT(smpl->iface->apply); smpl->iface->apply(smpl, cur_p); } void llama_sampler_reset(struct llama_sampler * smpl) { + if (!smpl) { + return; + } + if (smpl->iface->reset) { smpl->iface->reset(smpl); } } struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (!smpl) { + return nullptr; + } + if (smpl->iface->clone) { return smpl->iface->clone(smpl); } @@ -405,37 +425,200 @@ void llama_sampler_free(struct llama_sampler * smpl) { delete smpl; } -llama_token llama_sampler_sample(struct 
llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); +// empty sampler - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); +struct llama_sampler_empty { + const char * name; +}; - const int n_vocab = llama_vocab_n_tokens(vocab); +static struct llama_sampler * llama_sampler_init_empty(const char * name); + +static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return ctx->name; +} + +static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { + GGML_UNUSED(smpl); + GGML_UNUSED(token); +} + +static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + GGML_UNUSED(smpl); + GGML_UNUSED(cur_p); +} + +static void llama_sampler_empty_reset(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return llama_sampler_init_empty(ctx->name); +} + +static void llama_sampler_empty_free(struct llama_sampler * smpl) { + delete (llama_sampler_empty *) smpl->ctx; +} + +static bool llama_sampler_empty_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + GGML_UNUSED(smpl); + GGML_UNUSED(buft); + + return true; +} + +static void llama_sampler_empty_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(selected_token); +} + +static void llama_sampler_empty_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(data); +} - // TODO: do 
not allocate each time - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); +static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler_i llama_sampler_empty_i = { + /* .name = */ llama_sampler_empty_name, + /* .accept = */ llama_sampler_empty_accept, + /* .apply = */ llama_sampler_empty_apply, + /* .reset = */ llama_sampler_empty_reset, + /* .clone = */ llama_sampler_empty_clone, + /* .free = */ llama_sampler_empty_free, + /* .backend_init = */ llama_sampler_empty_backend_init, + /* .backend_accept = */ llama_sampler_empty_backend_accept, + /* .backend_apply = */ llama_sampler_empty_backend_apply, + /* .backend_set_input = */ llama_sampler_empty_backend_set_input, +}; + +struct llama_sampler * llama_sampler_init_empty(const char * name) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_empty_i, + /* .ctx = */ new llama_sampler_empty { + /* .name = */ name, + } + ); +} + +// common backend sampler functionality +// +// +name : means that the sampler is support and will run on the backend +// -name : means that a ggml operator is not supported by the backend +// +struct llama_sampler_backend { + llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {} + + const char * get_name() { + if (!is_init) { + return name.c_str(); + } + + if (support) { + name_ext = "+" + name; + } else { + name_ext = "-" + name; + } + + return name_ext.c_str(); } - llama_token_data_array cur_p = { - /* .data = */ cur.data(), - /* .size = */ cur.size(), - /* .selected = */ -1, - /* .sorted = */ false, + void init(bool support) { + GGML_ASSERT(this->is_init == false); + + this->is_init = true; + this->support = support; + } + +private: + std::string name; + std::string name_ext; + + bool is_init; + bool support; +}; + +// check if 
all ggml ops used by the sampler are supported by the backend +static bool llama_sampler_backend_support( + llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * device = ggml_backend_buft_get_device(buft); + if (!device) { + // CPU backend always supported + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, }; - llama_sampler_apply(smpl, &cur_p); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } - GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + ggml_context * ctx = ctx_ptr.get(); - auto token = cur_p.data[cur_p.selected].id; + const int64_t n = 1024*1024; - llama_sampler_accept(smpl, token); + llama_sampler_data data = { + /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n), + /*.probs = */ nullptr, + /*.sampled = */ nullptr, + /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n), + }; - return token; + ggml_cgraph * gf = ggml_new_graph(ctx); + + smpl->iface->backend_apply(smpl, ctx, gf, &data); + + if (data.logits) { + ggml_build_forward_expand(gf, data.logits); + } + + if (data.probs) { + ggml_build_forward_expand(gf, data.probs); + } + + if (data.sampled) { + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.candidates) { + ggml_build_forward_expand(gf, data.candidates); + } + + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor * op = ggml_graph_node(gf, i); + + if (!ggml_backend_dev_supports_op(device, op)) { + LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n", + __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl)); + + return false; + } + } + + return true; } // sampler chain @@ -449,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token time_meas 
tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); + for (auto & smpl : chain->samplers) { + llama_sampler_accept(smpl.ptr, token); } chain->n_sample++; @@ -461,16 +644,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); + bool is_backend = chain->is_init; + + for (auto & smpl : chain->samplers) { + if (is_backend && smpl.is_backend) { + continue; + } + + is_backend = false; + + if (smpl.ptr->iface->apply == nullptr) { + continue; + } + + llama_sampler_apply(smpl.ptr, cur_p); } } static void llama_sampler_chain_reset(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_reset(smpl.ptr); } } @@ -479,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl auto * result = llama_sampler_chain_init(chain_src->params); - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + for (const auto & smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr)); } return result; @@ -489,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl static void llama_sampler_chain_free(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_free(smpl.ptr); } delete chain; } +static bool llama_sampler_chain_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init == false && 
"llama_sampler_chain_backend_init() called twice"); + + chain->is_init = true; + + bool res = true; + + for (auto & smpl : chain->samplers) { + bool res_cur = true; + + // to be able to run a sampler on the backend, it has to: + // - have the .backend_init() API implemented + // - return true during .backend_init() + if (smpl.ptr->iface->backend_init) { + if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) { + res_cur = false; + } + } else { + res_cur = false; + } + + smpl.is_backend = res_cur; + + res = res && res_cur; + } + + return res; +} + +static void llama_sampler_chain_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_accept) { + smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called"); + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_apply) { + smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data); + } + } +} + +static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_set_input) { + smpl.ptr->iface->backend_set_input(smpl.ptr); + } + } +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset 
= */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .backend_init = */ llama_sampler_chain_backend_init, + /* .backend_accept = */ llama_sampler_chain_backend_accept, + /* .backend_apply = */ llama_sampler_chain_backend_apply, + /* .backend_set_input = */ llama_sampler_chain_backend_set_input, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -510,26 +794,113 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param /* .iface = */ &llama_sampler_chain_i, /* .ctx = */ new llama_sampler_chain { /* .params = */ params, + /* .is_init = */ false, /* .samplers = */ {}, + /* .cur = */ {}, /* .t_sample_us = */ 0, /* .n_sample = */ 0, } ); } +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. + if (sampled_token != LLAMA_TOKEN_NULL) { + LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. 
Skipping CPU samplers\n", __func__, idx); + return sampled_token; + } + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + // use pre-allocated buffer from chain if available, otherwise allocate locally + std::vector<llama_token_data> * cur_ptr; + std::vector<llama_token_data> cur_local; + + if (smpl->iface == &llama_sampler_chain_i) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + cur_ptr = &chain->cur; + } else { + cur_ptr = &cur_local; + } + + auto & cur = *cur_ptr; + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; +} + + void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; - p->samplers.push_back(smpl); + p->samplers.push_back({ + /* .is_backend = 
*/ false, + /* .ptr = */ smpl, + }); } -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { +struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) { + if (chain == nullptr) { + return nullptr; + } + + if (chain->iface != &llama_sampler_chain_i) { + return nullptr; + } + + if (i == -1) { + return chain; + } + const auto * p = (const llama_sampler_chain *) chain->ctx; if (i < 0 || (size_t) i >= p->samplers.size()) { return nullptr; } - return p->samplers[i]; + return p->samplers[i].ptr; } struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { @@ -539,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, return nullptr; } - auto * result = p->samplers[i]; + auto * result = p->samplers[i].ptr; p->samplers.erase(p->samplers.begin() + i); return result; @@ -557,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { // greedy -static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { - return "greedy"; +struct llama_sampler_greedy : public llama_sampler_backend { +}; + +static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + return sctx->get_name(); +} + +static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_greedy *) smpl->ctx; + GGML_UNUSED(ctx); +} + +static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; + auto * result = llama_sampler_init_greedy(); + + // copy the state + { + auto * result_ctx = (llama_sampler_greedy *) result->ctx; + + GGML_UNUSED(ctx); + GGML_UNUSED(result_ctx); + } + + return result; +} + +static void llama_sampler_greedy_free(struct llama_sampler * smpl) { + delete (llama_sampler_greedy *) smpl->ctx; } static void 
llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { @@ -570,33 +969,68 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } +static bool llama_sampler_greedy_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_greedy_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + + struct ggml_tensor * curl = ggml_argmax(ctx, data->logits); + ggml_set_name(curl, "greedy_argmax"); + + data->sampled = curl; +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ llama_sampler_greedy_reset, + /* .clone = */ llama_sampler_greedy_clone, + /* .free = */ llama_sampler_greedy_free, + /* .backend_init = */ llama_sampler_greedy_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_greedy_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { return llama_sampler_init( /* .iface = */ &llama_sampler_greedy_i, - /* .ctx = */ nullptr + /* .ctx = */ new llama_sampler_greedy { + ("greedy"), + } ); } // dist -struct llama_sampler_dist { +struct llama_sampler_dist : public llama_sampler_backend { const uint32_t seed; uint32_t seed_cur; std::mt19937 rng; + + ggml_tensor * inp_uniform; }; -static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { 
- return "dist"; +static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -671,6 +1105,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da #endif } +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; auto * result = llama_sampler_init_dist(ctx->seed); @@ -685,23 +1125,106 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample return result; } -static void llama_sampler_dist_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->seed_cur = get_rng_seed(ctx->seed); - ctx->rng.seed(ctx->seed_cur); -} - static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } +static bool llama_sampler_dist_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_dist_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "dist_probs"); + + 
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + ggml_set_name(cumsum, "dist_cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available. 
+ struct ggml_tensor * sampled_token = idx; + if (data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates)); + + sampled_token = ggml_get_rows(ctx, candidates, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + data->sampled = sampled_token; + data->probs = probs; +} + +static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + GGML_ASSERT(sctx->inp_uniform != nullptr); + + // We sample in double precision and cast to float to match rnd numbers of + // llama_sampler_dist which uses double precision (sampling from + // std::uniform_real_distribution<float> and + // std::uniform_real_distribution<double> with same rng will produce + // different sequences). + std::uniform_real_distribution<double> dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .backend_init = */ llama_sampler_dist_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_dist_backend_apply, + /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -709,21 +1232,24 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - 
/* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::mt19937(seed_cur), + ("dist"), + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .inp_uniform = */ nullptr, } ); } // top-k -struct llama_sampler_top_k { +struct llama_sampler_top_k : public llama_sampler_backend { const int32_t k; }; -static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { - return "top-k"; +static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -740,19 +1266,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } +static bool llama_sampler_top_k_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_k_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); + ggml_set_name(top_k, "top_k"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, top_k); + data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k); + ggml_set_name(data->candidates, "top_k_candidates"); + } else { + data->candidates = top_k; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, 
top_k); + data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); + ggml_set_name(top_k_rows, "top_k_rows"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .backend_init = */ llama_sampler_top_k_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_k_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + const bool is_empty = (k <= 0); + + if (is_empty) { + return llama_sampler_init_empty("?top-k"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { + ("top-k"), /* .k = */ k, } ); @@ -760,15 +1336,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { // top-p -struct llama_sampler_top_p { +struct llama_sampler_top_p : public llama_sampler_backend { const float p; const size_t min_keep; std::vector<llama_token_data> buf_sort; }; -static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { - return "top-p"; +static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -835,19 +1412,115 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } +static bool llama_sampler_top_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t 
buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(ggml_nrows(a) == 1); + struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]); + struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b); + return ggml_reshape_1d(ctx, a_sorted, a->ne[0]); + }; + + // Get the sorted logits in descending order. + struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC); + ggml_set_name(sorted_idx, "top_p_sorted_idx"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx); + ggml_set_name(sorted_logits, "top_p_sorted_logits"); + + struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits); + ggml_set_name(softmax, "top_p_softmax"); + + // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. + if (data->candidates) { + data->candidates = ggml_sort(data->candidates, sorted_idx); + } else { + data->candidates = sorted_idx; + } + ggml_set_name(data->candidates, "top_p_candidates"); + + // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. 
+ struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax); + ggml_set_name(cdf, "top_p_cdf"); + + // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep + struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p); + ggml_set_name(cdf_scaled, "top_p_cdf_scaled"); + + struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled); + ggml_set_name(mask, "top_p_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "top_p_index_f32"); + + // prevent out-of-bounds access + idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1); + + // construct ones tensor to set the value in the mask + struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f); + ggml_set_name(ones, "top_p_ones"); + + // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) + struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + + mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); + mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); + + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); + ggml_set_name(top_p_bias, "top_p_bias"); + + data->logits = ggml_add(ctx, sorted_logits, top_p_bias); + ggml_set_name(data->logits, "top_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ 
llama_sampler_top_p_free, + /* .backend_init = */ llama_sampler_top_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + const bool is_empty = p >= 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?top-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { + ("top-p"), /* .p = */ p, /* .min_keep = */ min_keep, /* .buf_sort = */ {}, @@ -857,13 +1530,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { // min-p -struct llama_sampler_min_p { +struct llama_sampler_min_p : public llama_sampler_backend { const float p; const size_t min_keep; }; -static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { - return "min-p"; +static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -929,19 +1603,81 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } +static bool llama_sampler_min_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_min_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "max_idx"); + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, 
data->logits, 1, data->logits->ne[0]); + ggml_set_name(logits_rows, "logits_rows"); + + struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); + ggml_set_name(max_logit, "max_logit"); + + // Calculate the threshold value. + struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); + ggml_set_name(threshold, "min_p_threshold"); + + // Subtract the threshold from logits. + struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold); + + // Create a mask where logits below the threshold are 0 (discard), + // and others are 1 (keep). + struct ggml_tensor * mask = ggml_step(ctx, sub); + ggml_set_name(mask, "min_p_mask"); + + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); + ggml_set_name(min_p_bias, "min_p_bias"); + + data->logits = ggml_add(ctx, data->logits, min_p_bias); + ggml_set_name(data->logits, "min_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .backend_init = */ llama_sampler_min_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_min_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?min-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { 
+ ("min-p"), /* .p = */ p, /* .min_keep = */ min_keep, } @@ -1029,15 +1765,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + const bool is_empty = (p >= 1.0f); + + if (is_empty) { + return llama_sampler_init_empty("?typical"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { @@ -1049,12 +1795,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { // temp -struct llama_sampler_temp { +struct llama_sampler_temp : public llama_sampler_backend { const float temp; }; -static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { - return "temp"; +static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1072,19 +1819,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } +static void llama_sampler_backend_temp_sampling( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data, + float temp) { + if 
(temp <= 0.0f) { + // Find the most probable token index. + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "temp_max_idx"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx); + } else { + data->candidates = max_idx; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + data->logits = ggml_get_rows(ctx, logits_rows, max_idx); + + return; + } + + data->logits = ggml_scale(ctx, data->logits, 1.0f / temp); + + GGML_UNUSED(gf); +} + +static bool llama_sampler_temp_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); +} + static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .backend_init = */ llama_sampler_temp_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { + const 
bool is_empty = temp == 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { + ("temp"), /*.temp = */ temp, } ); @@ -1092,14 +1899,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) { // temp-ext -struct llama_sampler_temp_ext { +struct llama_sampler_temp_ext : public llama_sampler_backend { const float temp; const float delta; const float exponent; }; -static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { - return "temp-ext"; +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1182,24 +1990,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { delete (llama_sampler_temp_ext *) smpl->ctx; } +static bool llama_sampler_temp_ext_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_ext_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + // Revert to standard temperature scaling if delta or temp are non-positive. + if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) { + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); + return; + } + + // Calculate min_temp, max_temp, and max_entropy. 
+ const float min_temp = std::max(0.0f, sctx->temp - sctx->delta); + const float max_temp = sctx->temp + sctx->delta; + const float max_entropy = logf(data->logits->ne[0]); + + // Calculate the probabilities. + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "temp_ext_softmax_probs"); + + // Clamp probabilities to avoid log(0) which would give -inf + struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f); + ggml_set_name(probs_clamped, "temp_ext_probs_clamped"); + + // Calculate the entropy, entropy = -Σ(p * log(p)). + struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped); + struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs); + struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p); + struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f); + ggml_set_name(log_probs, "temp_ext_log_probs"); + ggml_set_name(p_log_p, "temp_ext_p_log_p"); + ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p"); + ggml_set_name(entropy, "temp_ext_entropy"); + + // Normalize the entropy, norm_entropy = entropy / max_entropy + struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy); + ggml_set_name(norm_entropy, "temp_ext_norm_entropy"); + + // Calculate the dynamic temperature: + // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent); + // + // Calculate powf(normalized_entropy, exponent) as + // norm_entropy^exponent = exp(exponent * log(norm_entropy)) + struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy); + struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent); + struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log); + // With pow_entropy computed we can now compute dyn_temp, scaling by + // (max_temp - min_temp) and then adding min_temp. 
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp); + ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy"); + ggml_set_name(scaled_log, "temp_ext_scaled_log"); + ggml_set_name(pow_entropy, "temp_ext_pow_entropy"); + ggml_set_name(dyn_temp, "temp_ext_dyn_temp"); + + // Scale the logits by the dynamic temperature + struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp); + ggml_set_name(scaled_logits, "temp_ext_scaled_logits"); + + data->logits = scaled_logits; +} + static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .backend_init = */ llama_sampler_temp_ext_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_ext_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init( + const bool is_empty = temp == 1.0f && delta <= 0.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp-ext"); + } + + auto * res = llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { + ("temp-ext"), /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, } ); + + return res; } // xtc @@ -1212,7 +2108,7 @@ struct llama_sampler_xtc { const uint32_t seed; uint32_t seed_cur; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { @@ -1277,16 
+2173,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { - auto seed_cur = get_rng_seed(seed); + const bool is_empty = (p <= 0.0f || t > 0.5f); + + if (is_empty) { + return llama_sampler_init_empty("?xtc"); + } + + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { @@ -1385,16 +2292,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * 
llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { - auto seed_cur = get_rng_seed(seed); + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { @@ -1484,12 +2396,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1601,12 +2517,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .backend_init = */ nullptr, + /* .backend_accept 
= */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1808,12 +2728,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1823,6 +2747,12 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); + const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); + + if (is_empty) { + return llama_sampler_init_empty("?penalties"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { @@ -1860,9 +2790,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t for (size_t i = 0; i < cur_p->size; ++i) { // Only count non-negative infinity values if (cur_p->data[i].logit != -INFINITY) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; - } + max = std::max(max, cur_p->data[i].logit); logits_sum += cur_p->data[i].logit; valid_count++; } @@ -1899,15 +2827,25 @@ static void 
llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + const bool is_empty = (n <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?top-n-sigma"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { @@ -2229,12 +3167,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t 
dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2245,6 +3187,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + if (!dry_enabled) { + return llama_sampler_init_empty("?dry"); + } + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { // Process sequence breakers for (size_t i = 0; i < num_breakers; ++i) { @@ -2313,18 +3259,186 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. 
of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. 
min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. 
+ ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum 
= */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias -struct llama_sampler_logit_bias { +struct llama_sampler_logit_bias : public llama_sampler_backend { const int32_t n_vocab; const std::vector logit_bias; std::vector to_search; + + struct ggml_tensor * inp_logit_bias; + struct ggml_tensor * inp_logit_idxs; }; -static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { - return "logit-bias"; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + return ctx->get_name(); } static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -2369,25 +3483,110 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; } +static void llama_sampler_logit_bias_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(ctx); + + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); + + cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); + cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); + cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); + + data->logits = 
ggml_add(ctx, data->logits, cur); +} + +static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + GGML_ASSERT(sctx->inp_logit_bias != nullptr); + GGML_ASSERT(sctx->inp_logit_idxs != nullptr); + + const size_t n = sctx->logit_bias.size(); + + std::vector data_logit_bias(n, 0.0f); + std::vector data_logit_idxs(n, 0); + for (size_t i = 0; i < n; ++i) { + const auto & lb = sctx->logit_bias[i]; + GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); + data_logit_bias[i] = lb.bias; + data_logit_idxs[i] = lb.token; + } + + ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); +} + +static bool llama_sampler_logit_bias_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + + sctx->init(true); + + if (sctx->logit_bias.empty()) { + return true; + } + + return true; +} + static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .backend_init = */ llama_sampler_logit_bias_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_logit_bias_backend_apply, + /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input, }; struct 
llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + const bool is_empty = n_logit_bias <= 0; + + if (is_empty) { + return llama_sampler_init_empty("?logit-bias"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias { - /* .n_vocab = */ n_vocab, - /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), - /* .to_search = */ {}, + ("logit-bias"), + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + /* .inp_logit_bias = */ nullptr, + /* .inp_logit_idxs = */ nullptr, } ); } @@ -2600,12 +3799,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { @@ -2637,7 +3840,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { if (smpl->iface == &llama_sampler_chain_i) { const auto * ctx = (const llama_sampler_chain *) smpl->ctx; for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { - const uint32_t seed = llama_sampler_get_seed(*it); + const uint32_t seed = llama_sampler_get_seed(it->ptr); if (seed != LLAMA_DEFAULT_SEED) { return seed; } diff --git 
a/llama/llama.cpp/src/llama-sampler.h b/llama/llama.cpp/src/llama-sampler.h new file mode 100644 index 00000000000..b9bfc20d251 --- /dev/null +++ b/llama/llama.cpp/src/llama-sampler.h @@ -0,0 +1,42 @@ +#pragma once + +#include "llama.h" + +#include + +struct llama_vocab; +struct llama_grammar; + +// sampler chain + +struct llama_sampler_chain { + llama_sampler_chain_params params; + + // has .backend_init() been called? + bool is_init = false; + + struct info { + bool is_backend; + + llama_sampler * ptr; + }; + + std::vector samplers; + + // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations + std::vector cur; + + // timing + + mutable int64_t t_sample_us; + + mutable int32_t n_sample; +}; + +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector> & seq_breakers); diff --git a/llama/llama.cpp/src/llama-sampling.h b/llama/llama.cpp/src/llama-sampling.h deleted file mode 100644 index 759dd7dcb70..00000000000 --- a/llama/llama.cpp/src/llama-sampling.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? 
- -#include "llama.h" - -#include - -struct llama_vocab; -struct llama_grammar; - -// sampler chain - -struct llama_sampler_chain { - llama_sampler_chain_params params; - - std::vector samplers; - - // timing - - mutable int64_t t_sample_us; - - mutable int32_t n_sample; -}; - -struct llama_sampler * llama_sampler_init_dry_testing( - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const std::vector>& seq_breakers); diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index d63ce9c8493..b244318bdda 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,10 +285,19 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_JAIS2: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + 
+ // adapted: same as llama3 but with cascading whitespace pattern + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: regex_exprs = { @@ -299,7 +308,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: regex_exprs = { "[\r\n]", - "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z\U00010400-\U0001044f𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", "\\s?[!-/:-~!-/:-~‘-‟ -。]+", "\\s+$", "[一-龥ࠀ-一가-퟿]+", @@ -308,12 +317,19 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_YOUTU: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| 
?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -355,12 +371,20 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: + case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -408,6 +432,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + "\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: 
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -454,6 +486,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1738,26 +1777,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - 
throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + std::string first; + std::string second; - std::string first; - std::string second; + const size_t pos = word.find(' ', 1); - const size_t pos = word.find(' ', 1); + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + bpe_ranks.emplace(std::make_pair(first, second), i); } - - bpe_ranks.emplace(std::make_pair(first, second), i); } // default special tokens @@ -1833,7 +1879,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || tokenizer_pre == "midm-2.0" || - tokenizer_pre == "lfm2") { + tokenizer_pre == "lfm2" || + tokenizer_pre == "jina-v5-nano") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1849,6 +1896,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "youtu") { + pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == 
"falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -1867,8 +1919,12 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jais-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1888,6 +1944,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -1941,6 +2001,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "exaone-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -1953,10 +2016,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4" || + tokenizer_pre == "kanana2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; @@ -1987,6 +2055,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre 
== "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; @@ -2003,6 +2075,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "minimax-m2") { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; clean_spaces = false; + } else if ( + tokenizer_pre == "solar-open") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; + clean_spaces = false; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -2049,6 +2125,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } + const uint32_t n_scores = score_idx != -1 ? gguf_get_arr_n(ctx, score_idx) : 0; const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { @@ -2070,7 +2147,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { auto & token_data = id_to_token[i]; token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; + token_data.score = (scores && i < n_scores) ? scores[i] : 0.0f; token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file @@ -2176,6 +2253,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, we apply this workaround to find the tokens based on their text for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. 
if (special_eot_id == LLAMA_TOKEN_NULL) { if (false @@ -2187,14 +2266,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "" // smoldocling ) { special_eot_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2205,10 +2285,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|eom_id|>" ) { special_eom_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2223,12 +2303,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
                         || t.first == "<|code_prefix|>" // GLM-4.5
+                        || t.first == "<|prefix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_pre_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2243,12 +2324,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_suffix|>" // GLM-4.5
+                        || t.first == "<|suffix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_suf_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2263,12 +2345,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_middle|>" // GLM-4.5
+                        || t.first == "<|middle|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_mid_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2280,12 +2363,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == ""   // Granite
                         || t.first == ""
+                        || t.first == "[PAD]" // Kimi-K2
                         ) {
                     special_fim_pad_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2300,10 +2384,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""    // Granite
                         ) {
                     special_fim_rep_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2314,18 +2398,44 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|file_sep|>" // Qwen
                         ) {
                     special_fim_sep_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
+                    }
+                }
+            }
+        }
+
+        // auto-detect unused tokens: e.g. control tokens with the word "unused"
+        // ideally, these tokens should be marked as unused during conversion
+        {
+            uint32_t n_unused = 0;
+
+            for (const auto & t : token_to_id) {
+                auto & attr = id_to_token[t.second].attr;
+
+                if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    continue;
+                }
+
+                if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) {
+                    if (strstr(t.first.c_str(), "unused") != NULL) {
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED);
                     }
                 }
+
+                if (attr & LLAMA_TOKEN_ATTR_UNUSED) {
+                    n_unused++;
+                }
             }
+
+            LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused);
         }
 
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        // ref: https://github.com/ggml-org/llama.cpp/issues/9606
         special_eog_ids.clear();
 
         if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
@@ -2341,39 +2451,53 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
 
         for (const auto & t : token_to_id) {
+            auto & attr = id_to_token[t.second].attr;
+
             if (false
                     || t.first == "<|eot_id|>"
                     || t.first == "<|im_end|>"
                     || t.first == "<|end|>"
                     || t.first == "<|return|>" // o200k_harmony
                     || t.first == "<|call|>"   // o200k_harmony
+                    || t.first == "<|flush|>"  // solar-open
+                    || t.first == "<|calls|>"  // solar-open
                     || t.first == ""
                     || t.first == "<|endoftext|>"
+                    || t.first == ""      // paddleocr
                     || t.first == "<|eom_id|>"
                     || t.first == ""
                     || t.first == "_"
+                    || t.first == "[EOT]" // Kimi-K2
+                    || t.first == "[EOS]" // Kimi-K2
                     || t.first == "<|end_of_text|>"
                     || t.first == "" // smoldocling
                ) {
                 special_eog_ids.insert(t.second);
-                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                     LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                             __func__, t.second, t.first.c_str());
-                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                 }
             } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
+                if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) {
+                    // token is control, but not marked as EOG -> print a debug log
+                    if (special_eog_ids.count(t.second) == 0) {
+                        LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                                __func__, t.second, t.first.c_str());
+                    }
                 }
             }
         }
 
         // @ngxson : quick hack for gpt-oss, always render these tokens
         for (const auto & t : token_to_id) {
+            auto & attr = id_to_token[t.second].attr;
+
             if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
-                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+                LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n",
+                        __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr);
+
+                attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
             }
         }
 
@@ -2393,34 +2517,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
 
-        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
-        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open),
         //       we remove the "<|end|>" token from the EOG list
         {
             bool has_return = false;
             bool has_call   = false;
             bool has_end    = false;
+            bool has_flush  = false;
 
             llama_token end_id = LLAMA_TOKEN_NULL;
 
             LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
             for (auto tid : special_eog_ids) {
-                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+                auto & text = id_to_token[tid].text;
+
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, text.c_str());
 
-                if (id_to_token[tid].text == "<|return|>") {
+                if (text == "<|return|>") {
                     has_return = true;
-                } else if (id_to_token[tid].text == "<|call|>") {
+                } else if (text == "<|call|>" || text == "<|calls|>") {
                     has_call = true;
-                } else if (id_to_token[tid].text == "<|end|>") {
+                } else if (text == "<|flush|>") {
+                    has_flush = true;
+                } else if (text == "<|end|>") {
                     has_end = true;
                     end_id = tid;
                 }
             }
 
-            if (has_return && has_call && has_end) {
+            if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) {
                 special_eog_ids.erase(end_id);
-                id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
-                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+
+                auto & attr = id_to_token[end_id].attr;
+                attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
             }
         }
     }
@@ -2518,6 +2650,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             for (const auto * token : {"", "", "<|endoftext|>"}) {
                 _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
             }
+        } else if (_contains_any(model_name, {"modern-bert"})) {
+            if (token_to_id.count("[MASK]") == 0 ) {
+                LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__);
+            }
+            else {
+                _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         }
     }
 }
@@ -2988,7 +3127,7 @@ std::vector llama_vocab::impl::tokenize(
 }
 
 int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
     const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
@@ -3211,34 +3350,34 @@ int32_t llama_vocab::impl::detokenize(
 }
 
 void llama_vocab::impl::print_info() const {
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: vocab type            = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab               = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges              = %u\n",     __func__, (uint32_t) bpe_ranks.size());
 
     // special tokens
-    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
-    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
-    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
-    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
-    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
-    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
-    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
-    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
-
-    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
-
-    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
-    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
-    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
-    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
-    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
-    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token             = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token             = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token             = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token             = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token             = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token             = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token             = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token            = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token              = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token         = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token         = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token         = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token         = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token         = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token         = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
 
     for (const auto & id : special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+        LLAMA_LOG_INFO( "%s: EOG token             = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
     }
 
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+    LLAMA_LOG_INFO("%s: max token length      = %d\n", __func__, max_token_len);
 }
 
 llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
index 55f8f3923c9..be5b08012df 100644
--- a/llama/llama.cpp/src/llama-vocab.h
+++ b/llama/llama.cpp/src/llama-vocab.h
@@ -51,6 +51,13 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
     LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2      = 41,
     LLAMA_VOCAB_PRE_TYPE_AFMOE           = 42,
+    LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN      = 43,
+    LLAMA_VOCAB_PRE_TYPE_YOUTU           = 44,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE      = 45,
+    LLAMA_VOCAB_PRE_TYPE_QWEN35          = 46,
+    LLAMA_VOCAB_PRE_TYPE_TINY_AYA        = 47,
+    LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM       = 48,
+    LLAMA_VOCAB_PRE_TYPE_JAIS2           = 49,
 };
 
 struct LLM_KV;
diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp
index 759152b767d..c5aec0816e7 100644
--- a/llama/llama.cpp/src/llama.cpp
+++ b/llama/llama.cpp/src/llama.cpp
@@ -71,8 +71,9 @@ static std::vector llama_get_device_memory_data(
     }, &ud);
 
     llama_model_params mparams_copy = *mparams;
-    mparams_copy.no_alloc = true;
-    mparams_copy.use_mmap = false;
+    mparams_copy.no_alloc  = true;
+    mparams_copy.use_mmap  = false;
+    mparams_copy.use_mlock = false;
 
     llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
     if (model == nullptr) {
@@ -110,8 +111,20 @@ static std::vector llama_get_device_memory_data(
         }
     }
     for (size_t i = 0; i < ret.size(); i++) {
-        size_t free, total;
+        size_t free;
+        size_t total;
         ggml_backend_dev_memory(model->devices[i], &free, &total);
+
+        // devices can return 0 bytes for free and total memory if they do not
+        // have any to report. in this case, we will use the host memory as a fallback
+        // fixes: https://github.com/ggml-org/llama.cpp/issues/18577
+        if (free == 0 && total == 0) {
+            ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            if (cpu_dev == nullptr) {
+                throw std::runtime_error(format("%s: no CPU backend found", __func__));
+            }
+            ggml_backend_dev_memory(cpu_dev, &free, &total);
+        }
         ret[i].free  = free;
         ret[i].total = total;
     }
@@ -139,12 +152,15 @@ enum layer_fraction_t {
 };
 // this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
 
+class llama_params_fit_exception : public std::runtime_error {
+    using std::runtime_error::runtime_error;
+};
+
 static void llama_params_fit_impl(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
-        size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
+        size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
     constexpr int64_t MiB = 1024*1024;
-    const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
     typedef std::vector dmds_t;
     const llama_model_params default_mparams = llama_model_default_params();
 
@@ -163,6 +179,12 @@ static void llama_params_fit_impl(
         return;
     }
 
+    std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
+    margins.reserve(nd);
+    for (size_t id = 0; id < nd; id++) {
+        margins.push_back(margins_s[id]);
+    }
+
     std::vector dev_names;
     {
         dev_names.reserve(nd);
@@ -180,11 +202,12 @@ static void llama_params_fit_impl(
         }
     }
 
-    int64_t sum_total          = 0;
-    int64_t sum_projected_free = 0;
-    int64_t min_projected_free = INT64_MAX;
-    int64_t sum_projected_used = 0;
-    int64_t sum_projected_ctx  = 0;
+    int64_t sum_free            = 0;
+    int64_t sum_projected_free  = 0;
+    int64_t sum_projected_used  = 0;
+    int64_t sum_projected_model = 0;
+    std::vector projected_free_per_device;
+    projected_free_per_device.reserve(nd);
 
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -194,63 +217,106 @@ static void llama_params_fit_impl(
 
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
+        projected_free_per_device.push_back(projected_free);
 
-        sum_total          += dmd.total;
-        sum_projected_used += projected_used;
-        sum_projected_free += projected_free;
-        min_projected_free  = std::min(min_projected_free, projected_free);
-        sum_projected_ctx  += dmd.mb.context;
+        sum_free            += dmd.free;
+        sum_projected_used  += projected_used;
+        sum_projected_free  += projected_free;
+        sum_projected_model += dmd.mb.model;
 
         if (nd > 1) {
-            LLAMA_LOG_INFO("%s:   - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
-                __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
-                projected_free >= 0 ? "surplus" : "deficit");
+            LLAMA_LOG_INFO("%s:   - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n",
+                __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
         }
     }
-    assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
-    assert(sum_projected_used >= sum_projected_ctx);
+    assert(sum_free >= 0 && sum_projected_used >= 0);
     LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
-        __func__, sum_projected_used/MiB, sum_total/MiB);
-    if (min_projected_free >= margin) {
-        if (nd == 1) {
+        __func__, sum_projected_used/MiB, sum_free/MiB);
+    if (nd == 1) {
+        if (projected_free_per_device[0] >= margins[0]) {
             LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
-                __func__, min_projected_free/MiB, margin/MiB);
+                __func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
+            return;
+        }
+    } else {
+        bool changes_needed = false;
+        for (size_t id = 0; id < nd; id++) {
+            if (projected_free_per_device[id] < margins[id]) {
+                changes_needed = true;
+                break;
+            }
+        }
+        if (!changes_needed) {
+            LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
             return;
         }
-        LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
-            __func__, min_projected_free/MiB, margin/MiB);
-        return;
     }
 
     // step 2: try reducing memory use by reducing the context size
 
     {
-        int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
+        int64_t global_surplus = sum_projected_free;
+        for (size_t id = 0; id < nd; id++) {
+            global_surplus -= margins[id];
+        }
         if (global_surplus < 0) {
-            LLAMA_LOG_INFO(nd == 1 ?
-                "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
-                "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
-                __func__, margin/MiB, -global_surplus/MiB);
+            if (nd == 1) {
+                LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n",
+                    __func__, margins[0]/MiB, -global_surplus/MiB);
+            } else {
+                LLAMA_LOG_INFO(
+                    "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n",
+                    __func__, -global_surplus/MiB);
+            }
             if (cparams->n_ctx == 0) {
                 if (hp_nct > n_ctx_min) {
-                    const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-                    const uint32_t ctx_reduction = std::min(
-                        uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
-                    cparams->n_ctx = hp_nct - ctx_reduction;
-                    const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
-                    global_surplus += memory_reduction;
-                    LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
-                        __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
-                    if (global_surplus >= 0) {
+                    int64_t sum_used_target = sum_free;
+                    for (size_t id = 0; id < nd; id++) {
+                        sum_used_target -= margins[id];
+                    }
+                    if (nd > 1) {
+                        // for multiple devices we need to be more conservative in terms of how much context we think can fit:
+                        //   - for dense models only whole layers can be assigned to devices
+                        //   - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
+                        //   - on average we expect a waste of 0.5 layers/tensors per device
+                        //   - use slightly more than the expected average for nd devices to be safe
+                        const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
+                        sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+                    }
+
+                    int64_t sum_projected_used_min_ctx = 0;
+                    cparams->n_ctx = n_ctx_min;
+                    const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
+                    for (const auto & dmd : dmds_min_ctx) {
+                        sum_projected_used_min_ctx += dmd.mb.total();
+                    }
+                    if (sum_used_target > sum_projected_used_min_ctx) {
+                        // linear interpolation between minimum and maximum context size:
+                        cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
+                            / (sum_projected_used - sum_projected_used_min_ctx);
+                        cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+
+                        const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
+                        const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
+                        LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+                            __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
                         if (nd == 1) {
                             LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
                             return;
                         }
                         LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+                    } else {
+                        const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
+                        LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+                            __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
                     }
                 } else {
-                    LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
-                        __func__, hp_nct, n_ctx_min);
+                    if (n_ctx_min == UINT32_MAX) {
+                        LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct);
+                    } else {
+                        LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
+                            __func__, hp_nct, n_ctx_min);
+                    }
                 }
             } else {
                 LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
@@ -259,32 +325,28 @@ static void llama_params_fit_impl(
     }
 
     if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
-        throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
+        throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
     }
     if (nd > 1) {
         if (!tensor_split) {
-            throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
+            throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort");
         }
         if (mparams->tensor_split) {
             for (size_t id = 0; id < nd; id++) {
                 if (mparams->tensor_split[id] != 0.0f) {
-                    throw std::runtime_error("model_params::tensor_split already set by user, abort");
+                    throw llama_params_fit_exception("model_params::tensor_split already set by user, abort");
                 }
             }
         }
         if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
-        }
-        if (hp_ngl < 2*nd) {
-            throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least "
-                + std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort");
+            throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
         }
     }
     if (!tensor_buft_overrides) {
-        throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
+        throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort");
     }
     if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
-        throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
+        throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort");
     }
 
     // step 3: iteratively fill the back to front with "dense" layers
@@ -337,6 +399,11 @@ static void llama_params_fit_impl(
 
         // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
         layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
+
+        uint32_t n_full() const {
+            assert(n_layer >= n_part);
+            return n_layer - n_part;
+        }
     };
 
     const size_t ntbo = llama_max_tensor_buft_overrides();
@@ -345,8 +412,7 @@ static void llama_params_fit_impl(
     auto set_ngl_tensor_split_tbo = [&](
             const std::vector & ngl_per_device,
             const std::vector & overflow_bufts,
-            llama_model_params & mparams,
-            const bool add_nonrepeating) {
+            llama_model_params & mparams) {
         mparams.n_gpu_layers = 0;
         for (size_t id = 0; id < nd; id++) {
             mparams.n_gpu_layers += ngl_per_device[id].n_layer;
@@ -354,29 +420,25 @@ static void llama_params_fit_impl(
                 tensor_split[id] = ngl_per_device[id].n_layer;
             }
         }
-        assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl);
-        uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides
+        assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1);
+        uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides
 
-        if (add_nonrepeating) {
-            mparams.n_gpu_layers += 1;
-            tensor_split[nd - 1] += 1;
-        }
         mparams.tensor_split = tensor_split;
 
         size_t itbo = 0;
         for (size_t id = 0; id < nd; id++) {
-            il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
+            il0 += ngl_per_device[id].n_full();
             for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
                 if (itbo + 1 >= ntbo) {
                     tensor_buft_overrides[itbo].pattern = nullptr;
                     tensor_buft_overrides[itbo].buft    = nullptr;
                     itbo++;
                     mparams.tensor_buft_overrides = tensor_buft_overrides;
-                    throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
-                        + std::to_string(ntbo) + " is insufficient for model\n");
+                    throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
+                        + std::to_string(ntbo) + " is insufficient for model");
                 }
                 tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
-                tensor_buft_overrides[itbo].buft = overflow_bufts[id];
+                tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type();
                 itbo++;
             }
             il0 += ngl_per_device[id].n_part;
@@ -391,10 +453,9 @@ static void llama_params_fit_impl(
     auto get_memory_for_layers = [&](
             const char * func_name,
             const std::vector & ngl_per_device,
-            const std::vector & overflow_bufts,
-            const bool add_nonrepeating) -> std::vector {
+            const std::vector & overflow_bufts) -> std::vector {
         llama_model_params mparams_copy = *mparams;
-        set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating);
+        set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy);
 
         const dmds_t dmd_nl = llama_get_device_memory_data(
             path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
@@ -427,9 +488,9 @@ static void llama_params_fit_impl(
         const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
             path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
 
-        for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
-            global_surplus_cpu_moe += dmd.free;
-            global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
+        for (size_t id = 0; id < nd; id++) {
+            global_surplus_cpu_moe += dmds_cpu_moe[id].free;
+            global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id];
         }
 
         if (global_surplus_cpu_moe > 0) {
@@ -448,27 +509,18 @@ static void llama_params_fit_impl(
     std::vector targets; // maximum acceptable memory use per device
     targets.reserve(nd);
     for (size_t id = 0; id < nd; id++) {
-        targets.push_back(dmds_full[id].free - margin);
+        targets.push_back(dmds_full[id].free - margins[id]);
         LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
     }
 
-    // whether for the optimal memory use we expect to load at least some MoE tensors:
-    const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0;
-
-    std::vector overflow_bufts; // which bufts the partial layers of a device overflow to:
+    std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to:
     overflow_bufts.reserve(nd);
-    for (size_t id = 0; id < nd - 1; ++id) {
-        overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
+    for (size_t id = 0; id < nd; id++) {
+        overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
     }
-    overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
 
     std::vector ngl_per_device(nd);
-    std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe);
-    if (hp_nex > 0) {
-        for (size_t id = 0; id < nd; id++) {
-            ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
-        }
-    }
+    std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
 
     // optimize the number of layers per device using the method of false position:
     //   - ngl_per_device has 0 layers for each device, lower bound
@@ -476,22 +528,30 @@ static void llama_params_fit_impl(
     //   - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target
     //   - check memory use of our guess, replace either the low or high bound
     //   - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits
+    //   - the last device has the output layer, which cannot be a partial layer
     if (hp_nex == 0) {
         LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__);
     } else {
         LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
     }
-    uint32_t n_unassigned = hp_ngl;
     for (int id = nd - 1; id >= 0; id--) {
+        uint32_t n_unassigned = hp_ngl + 1;
+        for (size_t jd = id + 1; jd < nd; ++jd) {
+            assert(n_unassigned >= ngl_per_device[jd].n_layer);
+            n_unassigned -= ngl_per_device[jd].n_layer;
+        }
+
         std::vector ngl_per_device_high = ngl_per_device;
         ngl_per_device_high[id].n_layer = n_unassigned;
         if (hp_nex > 0) {
-            ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer;
+            ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1;
         }
         if (ngl_per_device_high[id].n_layer > 0) {
-            std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
+            std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
             if (mem_high[id] > targets[id]) {
+                assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
                 uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
+                LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
                 while (delta > 1) {
                     uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
                     step_size = std::max(step_size, uint32_t(1));
@@ -500,25 +560,26 @@ static void llama_params_fit_impl(
                     std::vector ngl_per_device_test = ngl_per_device;
                     ngl_per_device_test[id].n_layer += step_size;
                     if (hp_nex) {
-                        ngl_per_device_test[id].n_part += step_size;
+                        ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ?
+                            step_size - 1 : step_size; // the first layer is the output layer which must always be full
                     }
-                    const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
+                    const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
 
                     if (mem_test[id] <= targets[id]) {
-                        ngl_per_device  = ngl_per_device_test;
-                        mem             = mem_test;
-                        n_unassigned   -= ngl_per_device[id].n_layer;
+                        ngl_per_device = ngl_per_device_test;
+                        mem            = mem_test;
                         LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
                     } else {
                         ngl_per_device_high = ngl_per_device_test;
                         mem_high            = mem_test;
-                        LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
+                        LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer);
                     }
                     delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
                 }
             } else {
-                ngl_per_device  = ngl_per_device_high;
-                n_unassigned   -= ngl_per_device[id].n_layer;
+                assert(ngl_per_device_high[id].n_layer == n_unassigned);
+                ngl_per_device = ngl_per_device_high;
+                mem            = mem_high;
                 LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
             }
         }
@@ -529,7 +590,7 @@ static void llama_params_fit_impl(
             __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB);
     }
     if (hp_nex == 0 || global_surplus_cpu_moe <= 0) {
-        set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
+        set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
         return;
     }
 
@@ -549,24 +610,20 @@ static void llama_params_fit_impl(
     assert(id_dense_start < nd);
 
     LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
-    for (size_t id = 0; id <= id_dense_start; id++) {
+    for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) {
         std::vector ngl_per_device_high = ngl_per_device;
         for (size_t jd = id_dense_start; jd < nd; jd++) {
-            const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer;
+            const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
             ngl_per_device_high[id].n_layer += n_layer_move;
             ngl_per_device_high[jd].n_layer -= n_layer_move;
             ngl_per_device_high[jd].n_part = 0;
         }
         size_t id_dense_start_high = nd - 1;
-        std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
+        std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
 
         if (mem_high[id] > targets[id]) {
-            assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
-            assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
-            assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                   >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
-            uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
+            assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
+            uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
             while (delta > 1) {
                 uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
                 step_size = std::max(step_size, uint32_t(1));
@@ -582,11 +639,11 @@ static void llama_params_fit_impl(
                     ngl_per_device_test[id].n_layer += n_convert_jd;
                     n_converted_test += n_convert_jd;
 
-                    if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
+                    if (ngl_per_device_test[id_dense_start_test].n_part > 0) {
                         break;
                     }
                 }
-                const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
+                const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
 
                 if (mem_test[id] <= targets[id]) {
                     ngl_per_device = ngl_per_device_test;
@@ -601,32 +658,38 @@ static void llama_params_fit_impl(
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
                         __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
                 }
-                delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                    - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
+                assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
+                delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
             }
         } else {
             ngl_per_device = ngl_per_device_high;
+            mem            = mem_high;
             id_dense_start = id_dense_start_high;
             LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n",
                 __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start);
         }
 
         // try to fit at least part of one more layer
-        if (ngl_per_device[id_dense_start].n_layer > 0) {
+        if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) {
             std::vector ngl_per_device_test = ngl_per_device;
             size_t id_dense_start_test = id_dense_start;
             ngl_per_device_test[id_dense_start_test].n_layer--;
             ngl_per_device_test[id_dense_start_test].n_part--;
             ngl_per_device_test[id].n_layer++;
             ngl_per_device_test[id].n_part++;
-            if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
+            if (ngl_per_device_test[id_dense_start_test].n_part == 0) {
                 id_dense_start_test++;
             }
             ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
+            std::vector overflow_bufts_test = overflow_bufts;
+            if (id < nd - 1) {
+                overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]);
+            }
             LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
-            std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
-            if (mem_test[id] < targets[id]) {
+            std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
+            if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                 ngl_per_device = ngl_per_device_test;
+                overflow_bufts = overflow_bufts_test;
                 mem            = mem_test;
                 id_dense_start = id_dense_start_test;
                 LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
@@ -634,9 +697,10 @@ static void llama_params_fit_impl(
 
                 ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
                 LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
-                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
-                if (mem_test[id] < targets[id]) {
+                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
+                if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                     ngl_per_device = ngl_per_device_test;
+                    overflow_bufts = overflow_bufts_test;
                     mem            = mem_test;
                     id_dense_start = id_dense_start_test;
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
@@ -645,9 +709,10 @@ static void llama_params_fit_impl(
             } else {
                 ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
                 LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
-                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe);
-                if (mem_test[id] < targets[id]) {
+                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
+                if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                     ngl_per_device = ngl_per_device_test;
+                    overflow_bufts = overflow_bufts_test;
                     mem            = mem_test;
                     id_dense_start = id_dense_start_test;
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
@@ -662,30 +727,41 @@ static void llama_params_fit_impl(
             __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
     }
 
-    set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe);
+    // print info for devices that were not changed during the conversion from dense only to full layers:
+    for (size_t id = id_dense_start + 1; id < nd; id++) {
+        const int64_t projected_margin = dmds_full[id].free - mem[id];
+        LLAMA_LOG_INFO(
+            "%s:   - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
+            __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
+    }
+
+    set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
 }
 
-bool llama_params_fit(
+enum llama_params_fit_status llama_params_fit(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
-        size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
+        size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) {
     const int64_t t0_us = llama_time_us();
-    bool ok = true;
+    llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
     try {
-        llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
+        llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level);
         LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
-    } catch (const std::runtime_error & e) {
+    } catch (const llama_params_fit_exception & e) {
         LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
-        ok = false;
+        status = LLAMA_PARAMS_FIT_STATUS_FAILURE;
+    } catch (const std::runtime_error & e) {
+        LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what());
+        status = LLAMA_PARAMS_FIT_STATUS_ERROR;
     }
     const int64_t t1_us = llama_time_us();
     LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
-    return ok;
+    return status;
 }
 
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
-        /*.no_perf                     =*/ true,
+        /*.no_perf =*/ true,
     };
 
     return result;
@@ -758,7 +834,7 @@ static int llama_model_load(const std::string & fname, std::vector
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
 
         ml.print_info();
 
@@ -1021,25 +1097,55 @@ int32_t llama_chat_apply_template(
 // model split
 //
 
-int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+int32_t llama_split_path(
+    char * split_path,
+    size_t maxlen,
+    const char * path_prefix,
+    int32_t split_no,
+    int32_t split_count) {
+
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
-    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
-        return strlen(split_path);
+
+    const int written = snprintf(
+        split_path,
+        maxlen,
+        SPLIT_PATH_FORMAT,
+        path_prefix,
+        split_no + 1,
+        split_count
+    );
+
+    if (written < 0 || (size_t) written >= maxlen) {
+        return 0;
     }
-    return 0;
+
+    return (int32_t) written;
 }
 
-int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
-    std::string str_split_path(split_path);
+int32_t llama_split_prefix(
+    char * split_prefix,
+    size_t maxlen,
+    const char * split_path,
+    int32_t split_no,
+    int32_t split_count) {
+
+    const std::string str_split_path(split_path);
+
     char postfix[32];
-    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
-    std::string str_postfix(postfix);
-
-    // check if split_prefix ends with postfix
-    int size_prefix = str_split_path.size() - str_postfix.size();
-    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
-        return size_prefix;
+    snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
+
+    const std::string str_postfix(postfix);
+    if (str_split_path.size() <= str_postfix.size()) {
+        return 0;
+    }
+
+    const size_t size_prefix = str_split_path.size() - str_postfix.size();
+
+    if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
+        const size_t copy_len = std::min(size_prefix + 1, maxlen);
+        snprintf(split_prefix, copy_len, "%s", split_path);
+
+        return (int32_t) size_prefix;
     }
 
     return 0;
diff --git a/llama/llama.cpp/src/models/afmoe.cpp b/llama/llama.cpp/src/models/afmoe.cpp
index 0192e344ca0..6a752a403f6 100644
--- a/llama/llama.cpp/src/models/afmoe.cpp
+++ b/llama/llama.cpp/src/models/afmoe.cpp
@@ -22,8 +22,15 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
     const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
+        const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
+                              (il + 1) % hparams.n_no_rope_layer_step != 0;
+
         // dual attention normalization (pre)
         cur = build_norm(inpL,
                 model.layers[il].attn_norm, NULL,
@@ -56,19 +63,16 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
             cb(Qcur, "Qcur_normed", il);
             cb(Kcur, "Kcur_normed", il);
 
-            // RoPE only for sliding_attention layers
-            const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
-                                ((il + 1) % hparams.n_no_rope_layer_step) != 0;
             if (use_rope) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur_rope", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur_rope", il);
             }
diff --git a/llama/llama.cpp/src/models/bert.cpp b/llama/llama.cpp/src/models/bert.cpp
index 3274fa3b99d..bca0e254fc5 100644
--- a/llama/llama.cpp/src/models/bert.cpp
+++ b/llama/llama.cpp/src/models/bert.cpp
@@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
                     LLM_FFN_GELU, LLM_FFN_SEQ, il);
             cb(cur, "ffn_out", il);
         } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
+            const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff();
+            auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU;
             cur = build_ffn(cur,
-                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL,
-                    model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
+                    type_op, LLM_FFN_PAR, il);
             cb(cur, "ffn_out", il);
         } else {
             cur = build_ffn(cur,
diff --git a/llama/llama.cpp/src/models/cogvlm.cpp b/llama/llama.cpp/src/models/cogvlm.cpp
index edf0d1424ce..0ceae3aaeb5 100644
--- a/llama/llama.cpp/src/models/cogvlm.cpp
+++ b/llama/llama.cpp/src/models/cogvlm.cpp
@@ -3,12 +3,14 @@
 llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
-    float         kq_scale    = 1.0f / sqrtf(float(n_embd_head));
+    const float   kq_scale    = 1.0f / sqrtf(float(n_embd_head));
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    ggml_tensor *inpL, *cur;
+    ggml_tensor * inpL;
+    ggml_tensor * cur;
+
     inpL = build_inp_embd(model.tok_embd);
 
     ggml_tensor * inp_pos = build_inp_pos();
@@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
         }
 
         ggml_tensor * inpSA = inpL;
-        cur                 = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
 
         // build self attention
         {
diff --git a/llama/llama.cpp/src/models/cohere2-iswa.cpp b/llama/llama.cpp/src/models/cohere2-iswa.cpp
index b18aa8c4e6c..9334b5e4263 100644
--- a/llama/llama.cpp/src/models/cohere2-iswa.cpp
+++ b/llama/llama.cpp/src/models/cohere2-iswa.cpp
@@ -21,6 +21,9 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const
 
     for (int il = 0; il < n_layer; ++il) {
         const bool is_swa = hparams.is_swa(il);
+        // UNUSED:
+        // const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        // const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
diff --git a/llama/llama.cpp/src/models/deepseek2.cpp b/llama/llama.cpp/src/models/deepseek2.cpp
index 49382874baa..b608396e50e 100644
--- a/llama/llama.cpp/src/models/deepseek2.cpp
+++ b/llama/llama.cpp/src/models/deepseek2.cpp
@@ -2,14 +2,11 @@
 
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-
-    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+    const bool is_mla = hparams.is_mla();
 
     // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
-    const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
-    const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
 
     const int64_t n_embd_head_qk_rope = hparams.n_rot;
     const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
@@ -17,7 +14,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
     // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
@@ -43,11 +40,13 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
+    auto * inp_attn_k  =  is_mla ? build_attn_inp_k()  : nullptr;
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < effective_n_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         // norm
@@ -57,6 +56,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         // self_attention
         {
             ggml_tensor * q = NULL;
+
+            const bool is_lite = model.layers[il].wq;
+
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                 cb(q, "q", il);
@@ -124,14 +126,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
                 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                 // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
                 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                 cb(kv_cmpr, "kv_cmpr_reshape", il);
 
                 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                 cb(Kcur, "Kcur", il);
 
                 // {kv_lora_rank, 1, n_tokens}
@@ -145,7 +147,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
+                // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_k,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
             } else {
@@ -169,11 +171,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 Vcur = ggml_cont(ctx0, Vcur);
                 cb(Vcur, "Vcur_cont", il);
 
-                // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(Kcur, "Kcur", il);
 
                 if (inp_attn_scale) {
@@ -183,12 +184,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_kv,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }
         }
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == effective_n_layers - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -215,9 +216,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                true, hparams.expert_weights_scale,
+                hparams.expert_weights_scale, hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
-                il);
+                il,
+                nullptr,
+                model.layers[il].ffn_gate_up_exps);
             cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
diff --git a/llama/llama.cpp/src/models/delta-net-base.cpp b/llama/llama.cpp/src/models/delta-net-base.cpp
new file mode 100644
index 00000000000..99f1fdd9538
--- /dev/null
+++ b/llama/llama.cpp/src/models/delta-net-base.cpp
@@ -0,0 +1,376 @@
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+    const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
+    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]
+
+    const int CS = CHUNK_SIZE;
+
+    const int pad = (CS - n_tokens % CS) % CS;
+    const int n_chunks = (n_tokens + pad) / CS;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
+    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
+
+    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
+    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
+
+    cb(v_b, "v_b", il);
+    cb(k_b, "k_b", il);
+
+    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
+    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
+    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
+    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
+    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
+
+    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);
+
+    // [CS, g_0, n_chunks, H_v * n_seqs]
+    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
+    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
+    cb(g_cs, "g_cs", il);
+
+    ggml_tensor * kb = nullptr;
+    ggml_tensor * kq = nullptr;
+    if (kda) {
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB]
+
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+
+        // decay_mask [chunk_size,chunk_size,S_k,CHB]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
+
+        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS,  1, CHB);
+        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k,  1, CS, CHB);
+        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS,  1, CHB);
+
+        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
+        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);
+
+        // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
+        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);
+
+        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
+        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
+    } else {
+        ggml_tensor * g_cs_i = g_cs;
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+
+        // [CS, CS, n_chunks, H_v * n_seqs]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kb = ggml_mul_mat(ctx0, k,  k_b);
+        kb = ggml_mul    (ctx0, kb, decay_mask);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kq = ggml_mul_mat(ctx0, k, q);
+        kq = ggml_mul(ctx0, kq, decay_mask);
+    }
+
+    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(kq, "kq", il);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * attn;
+    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
+    cb(attn, "attn", il);
+
+    ggml_tensor * identity;
+    identity = ggml_view_1d(ctx0, attn, CS, 0);
+    identity = ggml_fill   (ctx0, identity, 1.0f);
+    identity = ggml_diag   (ctx0, identity);
+
+    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
+    cb(lhs, "dnet_add_ch_lhs", il);
+
+    attn = ggml_neg(ctx0, attn);
+    cb(attn, "attn_pre_solve", il);
+
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_add(ctx0, lin_solve, identity);
+    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
+
+    // [S_v, CS, n_chunks, H_v * n_seqs]
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
+
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
+
+    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
+
+    // [CS, S_k, n_chunks, H_k * n_seqs]
+    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
+    cb(kbg, "k_beta_g_exp", il);
+
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
+    cb(k_cd, "k_cumdecay", il);
+
+    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
+    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
+
+    // vectorized calculation of key_gdiff
+    // improved from the chunked version:
+    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    //   key_gdiff = key * g_diff.unsqueeze(-1)
+    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+    // get last element in g_cumsum along CS dimension (ne0)
+    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
+            g_cs->nb[1],
+            g_cs->nb[2],
+            g_cs->nb[3],
+            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
+    cb(g_last, "g_last", il);
+
+    // TODO: remove this cont when CUDA supports non-cont unary ops
+    g_last = ggml_cont(ctx0, g_last);
+
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
+    cb(g_last_exp_t, "g_last_exp_t", il);
+
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
+    cb(g_diff, "g_diff", il);
+
+    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
+
+    // [S_k, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
+    cb(kg, "key_gdiff", il);
+
+    // [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
+    cb(kg_t, "key_gdiff_t", il);
+
+    ggml_tensor * s_t = ggml_transpose(ctx0, s);
+    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
+    cb(s_t, "dnet_add_ch_state", il);
+
+    // [CS, S_v, n_chunks, H_v * n_seqs]
+    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
+        cb(v_t_p, "v_prime", il);
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
+        cb(v_t_new, "v_t_new", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
+        cb(v_attn, "v_attn", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+        cb(attn_inter, "attn_inter", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
+        cb(o_ch, "dnet_add_ch_attn_out", il);
+
+        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
+
+        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+        // TODO: head broadcast might not work here - probably will need a transpose
+        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
+
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
+
+        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
+        s_t = ggml_add(ctx0, s_t, kgv);
+        cb(s_t, "dnet_add_ch_state", il);
+    }
+
+    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
+
+    // truncate padded tokens
+    ggml_tensor * o = ggml_view_4d(ctx0, v,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(v->type, S_v),
+            ggml_row_size(v->type, S_v * CS * n_chunks),
+            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
+    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);
+    cb(s, "output_state", il);
+
+    return {o, s};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b, // beta
+        ggml_tensor * s, // state
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    // GDA: [1,  1,  H_v, n_seqs]
+    // KDA: [1, S_k, H_v, n_seqs]
+    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        1, H_v, n_seqs);
+
+    // [S_v, S_v, H_v, n_seqs]
+    g = ggml_exp(ctx0, g);
+    s = ggml_mul(ctx0, s, g);
+
+    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * sk;
+    sk = ggml_mul     (ctx0, s_t, k);
+    sk = ggml_sum_rows(ctx0, sk);
+
+    // [S_v, 1, H_v, n_seqs]
+    ggml_tensor * d;
+    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
+    d = ggml_mul(ctx0, d, b);
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * d_t;
+    d_t = ggml_transpose(ctx0, d);
+
+    // [S_v, S_v, H_v, n_seqs]
+    ggml_tensor * kd;
+    k  = ggml_repeat(ctx0, k, s);
+    kd = ggml_mul   (ctx0, k, d_t);
+
+    s_t = ggml_add(ctx0, s_t, kd);
+
+    cb(s_t, "dnet_add_ar_state", il);
+
+    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
+    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
+
+    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
+
+    return {o, s};
+}
diff --git a/llama/llama.cpp/src/models/eurobert.cpp b/llama/llama.cpp/src/models/eurobert.cpp
new file mode 100644
index 00000000000..86e3176edc0
--- /dev/null
+++ b/llama/llama.cpp/src/models/eurobert.cpp
@@ -0,0 +1,97 @@
+#include "models.h"
+
+llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "inp_embd", -1);
+
+    auto * inp_attn = build_attn_inp_no_cache();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * cur = inpL;
+
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+
+        {
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            Qcur = build_lora_mm(model.layers[il].wq, cur);
+            Kcur = build_lora_mm(model.layers[il].wk, cur);
+            Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        cur = ggml_add(ctx0, cur, inpL);
+
+        ggml_tensor * ffn_inp = cur;
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up, NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_embd", -1);
+    res->t_embd = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/exaone-moe.cpp b/llama/llama.cpp/src/models/exaone-moe.cpp
new file mode 100644
index 00000000000..bef5b2ad351
--- /dev/null
+++ b/llama/llama.cpp/src/models/exaone-moe.cpp
@@ -0,0 +1,146 @@
+#include "models.h"
+
+
+llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_k;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn_iswa = build_attn_inp_kv_iswa();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // use RoPE for SWA layers
+        const bool is_local_layer = hparams.is_swa(il);
+
+        // norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+            cb(Kcur, "Kcur_normed", il);
+
+            if (is_local_layer) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn_iswa,
+                model.layers[il].wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            cb(cur, "attn_out", il);
+        }
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // norm
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward network
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense branch
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL, NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, hparams.expert_weights_norm,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
+            {
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur,
+                        model.layers[il].ffn_up_shexp, NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // final norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/falcon-h1.cpp b/llama/llama.cpp/src/models/falcon-h1.cpp
index b641a094079..785a7e5e662 100644
--- a/llama/llama.cpp/src/models/falcon-h1.cpp
+++ b/llama/llama.cpp/src/models/falcon-h1.cpp
@@ -1,9 +1,7 @@
 #include "models.h"
 
-
-
 llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     ggml_tensor * cur;
diff --git a/llama/llama.cpp/src/models/gemma-embedding.cpp b/llama/llama.cpp/src/models/gemma-embedding.cpp
index 90a98f7abf0..944c198bf95 100644
--- a/llama/llama.cpp/src/models/gemma-embedding.cpp
+++ b/llama/llama.cpp/src/models/gemma-embedding.cpp
@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
@@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model,
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
 
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
diff --git a/llama/llama.cpp/src/models/gemma2-iswa.cpp b/llama/llama.cpp/src/models/gemma2-iswa.cpp
index 9cc59a53ee5..7a9198193ac 100644
--- a/llama/llama.cpp/src/models/gemma2-iswa.cpp
+++ b/llama/llama.cpp/src/models/gemma2-iswa.cpp
@@ -19,6 +19,9 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         // norm
         cur = build_norm(inpL,
                 model.layers[il].attn_norm, NULL,
@@ -43,12 +46,12 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll
 
             Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
             Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(Qcur, "Qcur", il);
diff --git a/llama/llama.cpp/src/models/gemma3.cpp b/llama/llama.cpp/src/models/gemma3.cpp
index ae60ef4790c..dec3fc4b8bc 100644
--- a/llama/llama.cpp/src/models/gemma3.cpp
+++ b/llama/llama.cpp/src/models/gemma3.cpp
@@ -10,10 +10,9 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
diff --git a/llama/llama.cpp/src/models/gemma3n-iswa.cpp b/llama/llama.cpp/src/models/gemma3n-iswa.cpp
index a0bdd6a15a1..7db6d3bf4ec 100644
--- a/llama/llama.cpp/src/models/gemma3n-iswa.cpp
+++ b/llama/llama.cpp/src/models/gemma3n-iswa.cpp
@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model),
@@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -248,20 +245,30 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
 // equivalent to get_per_layer_inputs() in python code
 // output shape: [n_embd_altup, n_layer, n_tokens]
 ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
-    auto          inp = std::make_unique<llm_graph_input_embd>();
+    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
     ggml_tensor * inp_per_layer;
     if (ubatch.token) {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         ggml_set_input(inp->tokens);
-        res->t_tokens = inp->tokens;
+        res->t_inp_tokens = inp->tokens;
         inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
         inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
         cb(inp_per_layer, "inp_per_layer_selected", -1);
+        res->add_input(std::move(inp));
     } else {
-        GGML_ABORT("TODO: support embd input");
+        // Vision embedding path: use padding token (ID=0) embedding
+        // TODO: verify if this is the correct behavior in transformers implementation
+        const int64_t embd_size = model.tok_embd_per_layer->ne[0];  // n_embd_altup * n_layer
+
+        // Extract and dequantize padding token embedding (row 0)
+        ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
+        inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
+
+        // Reshape to [n_embd_altup, n_layer, 1]
+        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
+        cb(inp_per_layer, "inp_per_layer_vision", -1);
     }
-    res->add_input(std::move(inp));
     return inp_per_layer;
 }
 
@@ -279,7 +286,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp
                                               -1);  // [n_embd_altup, n_layer, n_tokens]
     cb(per_layer_proj, "per_layer_proj", -1);
 
-    inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+    inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
     inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
     cb(inp_per_layer, "inp_per_layer", -1);
 
diff --git a/llama/llama.cpp/src/models/glm4.cpp b/llama/llama.cpp/src/models/glm4.cpp
index 204aa3932af..bcd837b30d6 100644
--- a/llama/llama.cpp/src/models/glm4.cpp
+++ b/llama/llama.cpp/src/models/glm4.cpp
@@ -29,7 +29,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    // Only process up to last layer (skip final NextN layer)
+    // Final layer tensors are loaded but not processed in forward pass
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         // Pre-attention norm
@@ -100,7 +103,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
         }
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -130,9 +133,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
             cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "post_mlp_norm", il);
         }
-        // Add residual connection after post-MLP norm
-        inpL = ggml_add(ctx0, cur, ffn_inp);
-        cb(inpL, "l_out", il);
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
     }
     // Final norm
     cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
diff --git a/llama/llama.cpp/src/models/granite-hybrid.cpp b/llama/llama.cpp/src/models/granite-hybrid.cpp
index f6ca4c17a21..726ecdcca77 100644
--- a/llama/llama.cpp/src/models/granite-hybrid.cpp
+++ b/llama/llama.cpp/src/models/granite-hybrid.cpp
@@ -2,7 +2,7 @@
 
 
 llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
diff --git a/llama/llama.cpp/src/models/jais2.cpp b/llama/llama.cpp/src/models/jais2.cpp
new file mode 100644
index 00000000000..a69fcaa3bb3
--- /dev/null
+++ b/llama/llama.cpp/src/models/jais2.cpp
@@ -0,0 +1,123 @@
+#include "models.h"
+
+// JAIS-2 model graph builder
+// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
+llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // KV input for attention
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        // Pre-attention LayerNorm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm,
+                model.layers[il].attn_norm_b,
+                LLM_NORM, il);
+        cb(cur, "attn_norm", il);
+
+        // Self-attention with separate Q, K, V projections
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            cb(Qcur, "Qcur_bias", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            cb(Kcur, "Kcur_bias", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            cb(Vcur, "Vcur_bias", il);
+
+            // Reshape for attention
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // Apply RoPE
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur_rope", il);
+            cb(Kcur, "Kcur_rope", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        // Residual connection
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // Pre-FFN LayerNorm
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm,
+                model.layers[il].ffn_norm_b,
+                LLM_NORM, il);
+        cb(cur, "ffn_norm", il);
+
+        // FFN with relu2 activation (ReLU squared) - no gate projection
+        // up -> relu2 -> down
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                NULL, NULL, NULL,  // no gate
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                NULL,
+                LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection
+        inpL = ggml_add(ctx0, cur, ffn_inp);
+        inpL = build_cvec(inpL, il);
+        cb(inpL, "l_out", il);
+    }
+
+    // Final LayerNorm
+    cur = build_norm(inpL,
+            model.output_norm,
+            model.output_norm_b,
+            LLM_NORM, -1);
+    cb(cur, "result_norm", -1);
+
+    res->t_embd = cur;
+
+    // Output projection
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/jamba.cpp b/llama/llama.cpp/src/models/jamba.cpp
index a0187772ccb..ceab5817407 100644
--- a/llama/llama.cpp/src/models/jamba.cpp
+++ b/llama/llama.cpp/src/models/jamba.cpp
@@ -1,6 +1,6 @@
 #include "models.h"
 
-llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     ggml_tensor * cur;
diff --git a/llama/llama.cpp/src/models/kimi-linear.cpp b/llama/llama.cpp/src/models/kimi-linear.cpp
new file mode 100644
index 00000000000..83d11241f8d
--- /dev/null
+++ b/llama/llama.cpp/src/models/kimi-linear.cpp
@@ -0,0 +1,392 @@
+#include "models.h"
+#include "ggml.h"
+
+#include "llama-memory-recurrent.h"
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+    const int64_t d_inner = head_dim * n_head;
+    const int64_t conv_state_size = (d_conv - 1) * d_inner;
+    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
+
+    // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
+    // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
+    // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
+    // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
+    // View Q conv state: offset 0, size conv_state_size per seq
+    // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
+    //   state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
+    // We want [d_conv-1, d_inner, n_seqs] view:
+    //   nb1 = (d_conv-1) * element_size (stride between channels)
+    //   nb2 = n_embd_r_total * element_size (stride between seqs)
+    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
+        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
+        qkv * conv_state_size * ggml_element_size(conv_state_all));
+
+    // Project the hidden states, then run a short causal conv over the time axis using
+    // the cached (d_conv-1)-column tail from the previous ubatch, updating that cache.
+    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
+
+    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
+
+    // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
+    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
+
+    // Save last (d_conv-1) columns back to Q conv state
+    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_x,
+            ggml_view_3d(ctx0, conv_states_all,
+                d_conv - 1, d_inner, n_seqs,
+                (d_conv - 1) * ggml_element_size(conv_states_all),           // nb1: contiguous within one channel's conv taps
+                n_embd_r_total * ggml_element_size(conv_states_all),         // nb2: stride between sequences (skip over K,V states)
+                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));  // offset to first seq's Q/K/V state
+    // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
+    // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
+    // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
+    // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
+    // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
+    // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
+    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
+
+    // Apply conv1d
+    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
+    // Reshape to 2D for bias add: {d_inner, n_tokens}
+    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
+    Xcur = ggml_silu(ctx0, Xcur);
+
+    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
+}
+
+llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+    // So we don't need inp_pos
+
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
+
+    // Output ids for selecting which tokens to output
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    // Kimi dimension constants
+    const int64_t n_head = hparams.n_head();
+    const int64_t head_dim = hparams.n_embd_head_kda;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // Verify batch consistency for recurrent layers
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // MLA params
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+    const int64_t kv_lora_rank = hparams.n_lora_kv;
+    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = hparams.n_rot;  // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        // Check layer type by checking which tensors exist
+        // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
+        bool is_kda = (layer.ssm_a != nullptr);
+        bool is_mla = (layer.wkv_a_mqa != nullptr);
+
+        if (is_kda) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from r_l tensor (Q, K, V each have separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            cb(conv_states_all, "conv_states_all", il);
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            cb(g1, "g1 f_b(f_a(cur))", il);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
+            cb(g1, "kda_g1", il);
+
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs);
+            cb(beta, "kda_beta", il);
+
+            beta = ggml_sigmoid(ctx0, beta);
+
+            // Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            // Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
+            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
+
+            const float eps_norm = hparams.f_norm_rms_eps;
+
+            Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
+            Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
+
+            // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
+            std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
+                build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
+                build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
+
+            ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
+            ggml_tensor * new_state = attn_out.second;
+            cb(output, "attn_output", il);
+            cb(new_state, "new_state", il);
+
+            // Update the recurrent states
+            ggml_build_forward_expand(gf,
+                                     ggml_cpy(ctx0, new_state,
+                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+            // Output gating g2 = g_b(g_a(x))
+            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
+            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
+            cb(g2, "g2 g_b(g_a(cur_2d))", il);
+            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
+
+            // Apply o_norm with sigmoid gating
+            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
+            // Formula: output = RMSNorm(x) * sigmoid(g)
+            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
+            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
+            cb(normed, "kda_normed", il);
+            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
+            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
+
+            // Output projection
+            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
+            cur = ggml_mul_mat(ctx0, layer.wo, gated);
+            cb(cur, "kda_out", il);
+
+        } else if (is_mla) {
+            // === MLA Layer (Multi-head Latent Attention) without KV Cache ===
+            // Reference: vLLM mla.py
+            // Step 1: Q projection and reshape
+            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
+            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
+
+            // Step 2: KV compression
+            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
+            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
+
+            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
+            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
+            // k_pe is used directly without RoPE
+            // Normalize kv_c
+            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {kv_lora_rank + n_embd_head_qk_rope, n_head, n_tokens}
+                // note: unlike deepseek2, the rope part goes last here - Kimi uses no RoPE, so the build_rope_shift() ordering constraint does not apply
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {kv_lora_rank + n_embd_head_qk_rope, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_pe + k_nope (broadcast k_pe to all heads)
+                // K = [k_pe, k_nope]; k_nope is [qk_nope_head_dim, n_head, n_tokens],
+                // k_pe is [qk_rope_head_dim, 1, n_tokens] repeated to [qk_rope, n_head, n_tokens]
+                // NOTE(review): Qcur keeps the raw [q_nope|q_pe] layout - confirm K's [pe|nope] order matches it
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
+        } else {
+            // Unknown layer type - this should not happen
+            GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
+        }
+
+        // On last layer, select only the output tokens
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FFN Norm
+        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                layer.ffn_up, NULL, NULL,
+                layer.ffn_gate, NULL, NULL,
+                layer.ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE layer
+            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                layer.ffn_exp_probs_b,
+                hparams.n_expert,
+                hparams.n_expert_used,
+                LLM_FFN_SILU, true,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        layer.ffn_up_shexp, NULL, NULL,
+                        layer.ffn_gate_shexp, NULL, NULL,
+                        layer.ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+        // Residual
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final Norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/lfm2.cpp b/llama/llama.cpp/src/models/lfm2.cpp
index 7f805d78795..cf01ad62557 100644
--- a/llama/llama.cpp/src/models/lfm2.cpp
+++ b/llama/llama.cpp/src/models/lfm2.cpp
@@ -1,18 +1,149 @@
 #include "models.h"
 
+#include "../llama-memory-hybrid-iswa.h"
 #include "../llama-memory-hybrid.h"
 
-
-llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context(params),
-    model(model) {
+template <bool iswa>
+llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    using inp_hybrid_type = std::conditional_t<iswa, llm_graph_input_mem_hybrid_iswa, llm_graph_input_mem_hybrid>;
+    using inp_attn_type   = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    using mem_hybrid_ctx  = std::conditional_t<iswa, llama_memory_hybrid_iswa_context, llama_memory_hybrid_context>;
+
+    // lambda helpers for readability
+    auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
+        GGML_ASSERT(!model.layers[il].ffn_up_b);
+        GGML_ASSERT(!model.layers[il].ffn_gate_b);
+        GGML_ASSERT(!model.layers[il].ffn_down_b);
+        return build_ffn(cur,
+            model.layers[il].ffn_up, NULL, NULL,
+            model.layers[il].ffn_gate, NULL, NULL,
+            model.layers[il].ffn_down, NULL, NULL,
+            NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+    };
+    auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
+        return build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
+                            static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+    };
+    auto build_attn_block = [&model, this](ggml_tensor *   cur,
+                                           ggml_tensor *   inp_pos,
+                                           inp_attn_type * inp_attn,
+                                           int             il) -> ggml_tensor * {
+        GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
+        const auto n_embd_head = hparams.n_embd_head_v;
+        const auto n_head_kv   = hparams.n_head_kv(il);
+
+        auto * q = build_lora_mm(model.layers[il].wq, cur);
+        cb(q, "model.layers.{}.self_attn.q_proj", il);
+        auto * k = build_lora_mm(model.layers[il].wk, cur);
+        cb(k, "model.layers.{}.self_attn.k_proj", il);
+        auto * v = build_lora_mm(model.layers[il].wv, cur);
+        cb(v, "model.layers.{}.self_attn.v_proj", il);
+
+        q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
+        k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
+        v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
+
+        // qk norm
+        q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+        cb(q, "model.layers.{}.self_attn.q_layernorm", il);
+        k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+        cb(k, "model.layers.{}.self_attn.k_layernorm", il);
+
+        // RoPE
+        q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                          attn_factor, beta_fast, beta_slow);
+        k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                          attn_factor, beta_fast, beta_slow);
+
+        cur = build_attn(inp_attn,
+                model.layers[il].wo, NULL,
+                q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+
+        cb(cur, "model.layers.{}.self_attn.out_proj", il);
+
+        return cur;
+    };
+    auto build_shortconv_block = [&model, this](ggml_tensor *        cur,
+                                                llm_graph_input_rs * inp_recr,
+                                                int                  il) -> ggml_tensor * {
+        const auto * mctx_cur = static_cast<const mem_hybrid_ctx *>(mctx)->get_recr();
+        const uint32_t kv_head      = mctx_cur->get_head();
+        const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t  n_seqs       = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs());
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
+        cb(bcx, "model.layers.{}.conv.in_proj", il);
+
+        constexpr auto n_chunks = 3;
+        GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
+        const auto chunk_size = bcx->ne[0] / n_chunks;
+        auto *     b          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             0 * chunk_size * ggml_element_size(bcx));
+        auto *     c          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             1 * chunk_size * ggml_element_size(bcx));
+        auto *     x          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             2 * chunk_size * ggml_element_size(bcx));
+
+        auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
+
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs    = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
+
+        bx = ggml_concat(ctx0, conv, bx, 0);
+        GGML_ASSERT(bx->ne[0] > conv->ne[0]);
+
+        // last d_conv columns is a new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
+                                       (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
+
+        // write the new conv state
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
+                                               ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
+                                                            kv_head * d_conv * n_embd * ggml_element_size(new_conv))));
+
+        auto * conv_kernel = model.layers[il].shortconv.conv;
+        auto * conv_out    = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        cb(conv_out, "model.layers.{}.conv.conv", il);
+
+        auto * y = ggml_mul(ctx0, c, conv_out);
+        y        = build_lora_mm(model.layers[il].shortconv.out_proj, y);
+        cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
+
+        return y;
+    };
+
+    // actual graph construction starts here
     ggml_tensor * cur = build_inp_embd(model.tok_embd);
     cb(cur, "model.embed_tokens", -1);
 
     ggml_build_forward_expand(gf, cur);
 
+    inp_hybrid_type * inp_hybrid = nullptr;
+    if constexpr (iswa) {
+        inp_hybrid = build_inp_mem_hybrid_iswa();
+    } else {
+        inp_hybrid = build_inp_mem_hybrid();
+    }
+
     ggml_tensor * inp_pos     = build_inp_pos();
-    auto *        inp_hybrid  = build_inp_mem_hybrid();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
@@ -54,122 +185,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const {
-    return build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                        model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
-                        static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
-}
-
-ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const {
-    GGML_ASSERT(!model.layers[il].ffn_up_b);
-    GGML_ASSERT(!model.layers[il].ffn_gate_b);
-    GGML_ASSERT(!model.layers[il].ffn_down_b);
-    return build_ffn(cur,
-        model.layers[il].ffn_up, NULL, NULL,
-        model.layers[il].ffn_gate, NULL, NULL,
-        model.layers[il].ffn_down, NULL, NULL,
-        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
-}
-
-ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor *             cur,
-                                               ggml_tensor *             inp_pos,
-                                               llm_graph_input_attn_kv * inp_attn,
-                                               int                       il) const {
-    GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
-    const auto n_embd_head = hparams.n_embd_head_v;
-    const auto n_head_kv   = hparams.n_head_kv(il);
-
-    auto * q = build_lora_mm(model.layers[il].wq, cur);
-    cb(q, "model.layers.{}.self_attn.q_proj", il);
-    auto * k = build_lora_mm(model.layers[il].wk, cur);
-    cb(k, "model.layers.{}.self_attn.k_proj", il);
-    auto * v = build_lora_mm(model.layers[il].wv, cur);
-    cb(v, "model.layers.{}.self_attn.v_proj", il);
-
-    q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
-    k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
-    v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
-
-    // qk norm
-    q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
-    cb(q, "model.layers.{}.self_attn.q_layernorm", il);
-    k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
-    cb(k, "model.layers.{}.self_attn.k_layernorm", il);
-
-    // RoPE
-    q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                      attn_factor, beta_fast, beta_slow);
-    k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                      attn_factor, beta_fast, beta_slow);
-
-    cur = build_attn(inp_attn,
-            model.layers[il].wo, NULL,
-            q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
-
-    cb(cur, "model.layers.{}.self_attn.out_proj", il);
-
-    return cur;
-}
-
-ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) {
-    const auto *   mctx_cur     = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-    const uint32_t kv_head      = mctx_cur->get_head();
-    const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
-    const int64_t  n_seqs       = ubatch.n_seqs;
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs());
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
-    const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
-    cb(bcx, "model.layers.{}.conv.in_proj", il);
-
-    constexpr auto n_chunks = 3;
-    GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
-    const auto chunk_size = bcx->ne[0] / n_chunks;
-    auto *     b          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         0 * chunk_size * ggml_element_size(bcx));
-    auto *     c          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         1 * chunk_size * ggml_element_size(bcx));
-    auto *     x          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         2 * chunk_size * ggml_element_size(bcx));
-
-    auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
-
-    // read conv state
-    auto * conv_state = mctx_cur->get_r_l(il);
-    auto * conv_rs    = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
-    auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
-
-    bx = ggml_concat(ctx0, conv, bx, 0);
-    GGML_ASSERT(bx->ne[0] > conv->ne[0]);
-
-    // last d_conv columns is a new conv state
-    auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
-                                   (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
-    GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
-
-    // write new conv conv state
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
-                                           ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
-                                                        kv_head * d_conv * n_embd * ggml_element_size(new_conv))));
-
-    auto * conv_kernel = model.layers[il].shortconv.conv;
-    auto * conv_out    = ggml_ssm_conv(ctx0, bx, conv_kernel);
-    cb(conv_out, "model.layers.{}.conv.conv", il);
-
-    auto * y = ggml_mul(ctx0, c, conv_out);
-    y        = build_lora_mm(model.layers[il].shortconv.out_proj, y);
-    cb(y, "model.layers.{}.conv.out_proj", il);
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
-
-    return y;
-}
+// Explicit template instantiations
+template struct llm_build_lfm2<false>;
+template struct llm_build_lfm2<true>;
diff --git a/llama/llama.cpp/src/models/llama-iswa.cpp b/llama/llama.cpp/src/models/llama-iswa.cpp
index 03f80616821..61dd2c179f1 100644
--- a/llama/llama.cpp/src/models/llama-iswa.cpp
+++ b/llama/llama.cpp/src/models/llama-iswa.cpp
@@ -25,8 +25,12 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
         const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
                               (il + 1) % hparams.n_no_rope_layer_step != 0;
 
@@ -67,13 +71,13 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
             if (use_rope) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
             } else if (inp_attn_scale) {
diff --git a/llama/llama.cpp/src/models/llama.cpp b/llama/llama.cpp/src/models/llama.cpp
index ab7fd5d0508..42b5fcdf42e 100644
--- a/llama/llama.cpp/src/models/llama.cpp
+++ b/llama/llama.cpp/src/models/llama.cpp
@@ -1,6 +1,7 @@
 #include "models.h"
 
-llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool embed>
+llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    using inp_attn_type = std::conditional_t<embed, llm_graph_input_attn_no_cache, llm_graph_input_attn_kv>;
+
+    inp_attn_type * inp_attn = nullptr;
+    if constexpr (embed) {
+        inp_attn = build_attn_inp_no_cache();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
 
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para
     cb(cur, "result_norm", -1);
     res->t_embd = cur;
 
-    // lm_head
-    cur = build_lora_mm(model.output, cur);
+    if constexpr (!embed) {
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
 
-    cb(cur, "result_output", -1);
-    res->t_logits = cur;
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+    }
 
     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_llama<false>;
+template struct llm_build_llama<true>;
diff --git a/llama/llama.cpp/src/models/maincoder.cpp b/llama/llama.cpp/src/models/maincoder.cpp
new file mode 100644
index 00000000000..da57308167e
--- /dev/null
+++ b/llama/llama.cpp/src/models/maincoder.cpp
@@ -0,0 +1,117 @@
+#include "models.h"
+
+llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/graph-context-mamba.cpp b/llama/llama.cpp/src/models/mamba-base.cpp
similarity index 97%
rename from llama/llama.cpp/src/models/graph-context-mamba.cpp
rename to llama/llama.cpp/src/models/mamba-base.cpp
index b9a363b32b6..aaac9487dfa 100644
--- a/llama/llama.cpp/src/models/graph-context-mamba.cpp
+++ b/llama/llama.cpp/src/models/mamba-base.cpp
@@ -1,8 +1,10 @@
 #include "models.h"
 
-llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
+#include "llama-memory-recurrent.h"
 
-ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
+llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
                                                          ggml_tensor *        cur,
                                                          const llama_model &  model,
                                                          const llama_ubatch & ubatch,
@@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in
     return cur;
 }
 
-ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
+ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
                                                           ggml_tensor *        cur,
                                                           const llama_model &  model,
                                                           const llama_ubatch & ubatch,
diff --git a/llama/llama.cpp/src/models/mamba.cpp b/llama/llama.cpp/src/models/mamba.cpp
index 46819613c2d..55fd2e055c4 100644
--- a/llama/llama.cpp/src/models/mamba.cpp
+++ b/llama/llama.cpp/src/models/mamba.cpp
@@ -1,7 +1,6 @@
 #include "models.h"
 
-
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
diff --git a/llama/llama.cpp/src/models/mimo2-iswa.cpp b/llama/llama.cpp/src/models/mimo2-iswa.cpp
new file mode 100644
index 00000000000..edc87cc9f0d
--- /dev/null
+++ b/llama/llama.cpp/src/models/mimo2-iswa.cpp
@@ -0,0 +1,123 @@
+
+#include "models.h"
+
+llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    ggml_tensor * inp_pos = build_inp_pos();
+    auto * inp_attn = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        uint32_t n_head_l    = hparams.n_head(il);
+        uint32_t n_head_kv_l = hparams.n_head_kv(il);
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // self_attention
+        {
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            ggml_tensor * sinks = model.layers[il].attn_sinks;
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward network
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense branch
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                                model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false,
+                                0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il);
+            cb(cur, "ffn_moe_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/minicpm3.cpp b/llama/llama.cpp/src/models/minicpm3.cpp
index f374a9fd030..297cc34ba58 100644
--- a/llama/llama.cpp/src/models/minicpm3.cpp
+++ b/llama/llama.cpp/src/models/minicpm3.cpp
@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
 
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
diff --git a/llama/llama.cpp/src/models/models.go b/llama/llama.cpp/src/models/models.go
index c6c42b6fee1..fda1290eae5 100644
--- a/llama/llama.cpp/src/models/models.go
+++ b/llama/llama.cpp/src/models/models.go
@@ -1,6 +1,6 @@
 package models
 
 // #cgo CXXFLAGS: -std=c++17
-// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../vendor
+// #cgo CPPFLAGS: -I${SRCDIR}/.. -I${SRCDIR}/../../include -I${SRCDIR}/../../vendor
 // #cgo CPPFLAGS: -I${SRCDIR}/../../../../ml/backend/ggml/ggml/include
 import "C"
diff --git a/llama/llama.cpp/src/models/models.h b/llama/llama.cpp/src/models/models.h
index 6d84a185d73..d076bf28809 100644
--- a/llama/llama.cpp/src/models/models.h
+++ b/llama/llama.cpp/src/models/models.h
@@ -1,23 +1,51 @@
 #pragma once
 
-#include "../llama-model.h"
-#include "../llama-graph.h"
+#include "llama-model.h"
+#include "llama-graph.h"
 
-// TODO: remove in follow-up PR - move to .cpp files
-#include "../llama-memory-recurrent.h"
+// note: almost all graphs require at least sqrtf, so include <cmath> globally
 #include <cmath>
 
-struct llm_graph_context_mamba : public llm_graph_context {
-    llm_graph_context_mamba(const llm_graph_params & params);
+//
+// base classes
+//
 
-    virtual ~llm_graph_context_mamba() = default;
+struct llm_build_mamba_base : public llm_graph_context {
+    llm_build_mamba_base(const llm_graph_params & params);
+
+    virtual ~llm_build_mamba_base() = default;
 
     ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
     ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
 
 };
 
-// Base class for RWKV-related models
+struct llm_build_delta_net_base : public llm_graph_context {
+    llm_build_delta_net_base(const llm_graph_params & params);
+
+    virtual ~llm_build_delta_net_base() = default;
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                int           il);
+};
+
 struct llm_build_rwkv6_base : public llm_graph_context {
     const llama_model & model;
 
@@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                                        int                  il) const;
 };
 
+//
+// models
+//
+
 struct llm_build_afmoe : public llm_graph_context {
     llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
 };
@@ -158,6 +190,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
     llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_paddleocr : public llm_graph_context {
+    llm_build_paddleocr(const llama_model & model, const llm_graph_params & params);
+};
+
 template <bool iswa>
 struct llm_build_exaone4 : public llm_graph_context {
     llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
@@ -167,11 +203,15 @@ struct llm_build_exaone : public llm_graph_context {
     llm_build_exaone(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_exaone_moe : public llm_graph_context {
+    llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_falcon : public llm_graph_context {
     llm_build_falcon(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+struct llm_build_falcon_h1 : public llm_build_mamba_base {
     llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
 };
 
@@ -249,7 +289,7 @@ struct llm_build_granite : public llm_graph_context {
         const int                 il);
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite_hybrid : public llm_build_mamba_base {
     llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
@@ -280,19 +320,44 @@ struct llm_build_jais : public llm_graph_context {
     llm_build_jais(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_jamba : public llm_graph_context_mamba {
+struct llm_build_jais2 : public llm_graph_context {
+    llm_build_jais2(const llama_model & model, const llm_graph_params & params);
+};
+
+struct llm_build_jamba : public llm_build_mamba_base {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_lfm2 : public llm_graph_context {
+struct llm_build_kimi_linear : public llm_build_delta_net_base {
+    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                        int   il);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
     const llama_model & model;
+};
 
+template 
+struct llm_build_lfm2 : public llm_graph_context {
     llm_build_lfm2(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const;
-    ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const;
-    ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const;
-    ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il);
-
 };
 
 struct llm_build_llada : public llm_graph_context {
@@ -303,6 +368,7 @@ struct llm_build_llada_moe : public llm_graph_context {
     llm_build_llada_moe(const llama_model & model, const llm_graph_params & params);
 };
 
+template <bool embed>
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params);
 };
@@ -311,10 +377,18 @@ struct llm_build_llama_iswa : public llm_graph_context {
     llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_mamba : public llm_graph_context_mamba {
+struct llm_build_maincoder : public llm_graph_context {
+    llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
+};
+
+struct llm_build_mamba : public llm_build_mamba_base {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_mimo2_iswa : public llm_graph_context {
+    llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_minicpm3 : public llm_graph_context {
     llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
 };
@@ -327,6 +401,10 @@ struct llm_build_mistral3 : public llm_graph_context {
     llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_modern_bert : public llm_graph_context {
+    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_mpt : public llm_graph_context {
     llm_build_mpt(const llama_model & model, const llm_graph_params & params);
 };
@@ -335,17 +413,21 @@ struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_nemotron_h : public llm_graph_context_mamba {
+struct llm_build_nemotron_h : public llm_build_mamba_base {
     llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
+    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
-        const llama_model & model, const int64_t n_embd_head, const int il);
+        const llama_model & model, int64_t n_embd_head, int il);
 };
 
 struct llm_build_neo_bert : public llm_graph_context {
     llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_eurobert : public llm_graph_context {
+    llm_build_eurobert(const llama_model & model, const llm_graph_params & params);
+};
+
 template <bool iswa>
 struct llm_build_olmo2 : public llm_graph_context {
     llm_build_olmo2(const llama_model & model, const llm_graph_params & params);
@@ -384,7 +466,7 @@ struct llm_build_phi3 : public llm_graph_context {
     llm_build_phi3(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_plamo2 : public llm_graph_context_mamba {
+struct llm_build_plamo2 : public llm_build_mamba_base {
     llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
     private:
         ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
@@ -396,6 +478,11 @@ struct llm_build_plamo : public llm_graph_context {
     llm_build_plamo(const llama_model & model, const llm_graph_params & params);
 };
 
+template 
+struct llm_build_plamo3 : public llm_graph_context {
+    llm_build_plamo3(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_plm : public llm_graph_context {
     llm_build_plm(const llama_model & model, const llm_graph_params & params);
 };
@@ -427,7 +514,8 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+
+struct llm_build_qwen3next : public llm_build_delta_net_base {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -439,35 +527,44 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
     ggml_tensor * build_layer_attn_linear(
          llm_graph_input_rs * inp,
                 ggml_tensor * cur,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
                         int   il);
 
     ggml_tensor * build_layer_ffn(
                 ggml_tensor * cur,
                         int   il);
 
-    ggml_tensor * build_delta_net_chunking(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
                         int   il);
 
-    ggml_tensor * build_delta_net_autoregressive(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                int           il);
+    const llama_model & model;
+};
+
+struct llm_build_qwen35 : public llm_build_delta_net_base {
+    llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int * sections,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
 
     ggml_tensor * build_norm_gated(
                 ggml_tensor * input,
@@ -475,6 +572,45 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
                 ggml_tensor * gate,
                         int   layer);
 
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+
+    const llama_model & model;
+};
+
+// TODO: derive llm_build_delta_net_base instead
+struct llm_build_qwen35moe : public llm_build_delta_net_base {
+    llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int * sections,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+
     const llama_model & model;
 };
 
@@ -532,6 +668,10 @@ struct llm_build_starcoder : public llm_graph_context {
     llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_step35_iswa : public llm_graph_context {
+    llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_t5_dec : public llm_graph_context {
     llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/llama/llama.cpp/src/models/modern-bert.cpp b/llama/llama.cpp/src/models/modern-bert.cpp
new file mode 100644
index 00000000000..32066c712b4
--- /dev/null
+++ b/llama/llama.cpp/src/models/modern-bert.cpp
@@ -0,0 +1,109 @@
+#include "models.h"
+
+llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // construct input embeddings (token, type, position)
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "inp_embd", -1);
+
+    // embed layer norm
+    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
+    cb(inpL, "inp_norm", -1);
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    auto * inp_attn = build_attn_inp_no_cache();
+
+    for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // attention layer norm
+        if (model.layers[il].attn_norm) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+        }
+
+        // self attention
+        cur = build_lora_mm(model.layers[il].wqkv, cur);
+        cb(cur, "wqkv", il);
+
+        const size_t type_size = ggml_type_size(cur->type);
+
+        ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd));
+        ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd));
+        ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa));
+
+        // RoPE
+        Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+        Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        cb(Vcur, "Vcur", il);
+
+        cur = build_attn(inp_attn,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        cb(cur, "kqv_out", il);
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        // re-add the layer input
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // attention layer norm
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                NULL,                      NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
+
+        // attentions bypass the intermediate layer
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM, -1);
+    cb(cur, "final_norm_out", -1);
+
+    res->t_embd = cur;
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/nemotron-h.cpp b/llama/llama.cpp/src/models/nemotron-h.cpp
index eb135e63f18..d61d62a8c96 100644
--- a/llama/llama.cpp/src/models/nemotron-h.cpp
+++ b/llama/llama.cpp/src/models/nemotron-h.cpp
@@ -1,9 +1,7 @@
 #include "models.h"
 
-
-
 llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -65,9 +63,9 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_
 ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *             cur,
                                                           llm_graph_input_attn_kv * inp_attn,
                                                           const llama_model &       model,
-                                                          const int64_t             n_embd_head,
-                                                          const int                 il) {
-    // compute Q and K and (optionally) RoPE them
+                                                                int64_t             n_embd_head,
+                                                                int                 il) {
+    // compute Q and K
     ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur, "Qcur", il);
     if (model.layers[il].bq) {
@@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
     return cur;
 }
 
-ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
+ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
     if (model.layers[il].ffn_gate_inp == nullptr) {
         cur = build_ffn(cur,
                 model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
diff --git a/llama/llama.cpp/src/models/openai-moe-iswa.cpp b/llama/llama.cpp/src/models/openai-moe-iswa.cpp
index 96596709eec..dbe3ca1851f 100644
--- a/llama/llama.cpp/src/models/openai-moe-iswa.cpp
+++ b/llama/llama.cpp/src/models/openai-moe-iswa.cpp
@@ -14,6 +14,9 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
         // norm
@@ -49,13 +52,13 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
 
             Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );
 
             Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );
 
diff --git a/llama/llama.cpp/src/models/openelm.cpp b/llama/llama.cpp/src/models/openelm.cpp
index ee46a3375e8..fbf682ec835 100644
--- a/llama/llama.cpp/src/models/openelm.cpp
+++ b/llama/llama.cpp/src/models/openelm.cpp
@@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_
             ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv));
             cb(Vcur, "Vcur", il);
 
             Qcur = build_norm(Qcur,
diff --git a/llama/llama.cpp/src/models/paddleocr.cpp b/llama/llama.cpp/src/models/paddleocr.cpp
new file mode 100644
index 00000000000..39a368df53b
--- /dev/null
+++ b/llama/llama.cpp/src/models/paddleocr.cpp
@@ -0,0 +1,122 @@
+#include "models.h"
+
+llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+
+    // NOTE: same with qwen2vl.cpp, but bias tensors are optional
+
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        {
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+        }
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            if (model.layers[il].bq) {
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+            }
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            if (model.layers[il].bk) {
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+            }
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            if (model.layers[il].bv) {
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+            }
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_multi(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_multi(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        {
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/models/plamo2.cpp b/llama/llama.cpp/src/models/plamo2.cpp
index 31115a08f95..3af236843bb 100644
--- a/llama/llama.cpp/src/models/plamo2.cpp
+++ b/llama/llama.cpp/src/models/plamo2.cpp
@@ -1,7 +1,9 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
diff --git a/llama/llama.cpp/src/models/plamo3.cpp b/llama/llama.cpp/src/models/plamo3.cpp
new file mode 100644
index 00000000000..55c8064679e
--- /dev/null
+++ b/llama/llama.cpp/src/models/plamo3.cpp
@@ -0,0 +1,128 @@
+#include "models.h"
+
+template <bool iswa>
+llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t head_dim_q = hparams.n_embd_head_k;
+    const int64_t head_dim_v = hparams.n_embd_head_v;
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * residual = inpL;
+
+        float freq_base_l  = 0.0f;
+        float freq_scale_l = 0.0f;
+        if constexpr (iswa) {
+            freq_base_l  = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l  = freq_base;
+            freq_scale_l = freq_scale;
+        }
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
+        cb(cur, "wqkv", il);
+
+        const int32_t n_head    = hparams.n_head(il);
+        const int32_t n_head_kv = hparams.n_head_kv(il);
+
+        const int64_t q_offset = 0;
+        const int64_t k_offset = head_dim_q * n_head;
+        const int64_t v_offset = k_offset + head_dim_q * n_head_kv;
+
+        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head, n_tokens,
+                head_dim_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head_kv, n_tokens,
+                head_dim_q * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
+        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim_v, n_head_kv, n_tokens,
+                head_dim_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));
+
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        cb(Vcur, "Vcur", il);
+
+        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+        cb(Qcur, "attn_q_norm", il);
+        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+        cb(Kcur, "attn_k_norm", il);
+
+        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow);
+        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow);
+
+        const float attn_scale = 1.0f / sqrtf(float(head_dim_q));
+
+        cur = build_attn(inp_attn,
+                model.layers[il].wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il);
+        cb(cur, "attn_out", il);
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur      = ggml_get_rows(ctx0, cur, inp_out_ids);
+            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+        }
+
+        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_post_norm", il);
+
+        cur = ggml_add(ctx0, cur, residual);
+        cb(cur, "attn_residual", il);
+
+        residual = cur;
+
+        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                NULL,                      NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+        cb(cur, "ffn_out", il);
+
+        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_post_norm", il);
+
+        cur = ggml_add(ctx0, cur, residual);
+        cb(cur, "ffn_residual", il);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+// Explicit template instantiations
+template struct llm_build_plamo3<false>;
+template struct llm_build_plamo3<true>;
diff --git a/llama/llama.cpp/src/models/plm.cpp b/llama/llama.cpp/src/models/plm.cpp
index 481cbba6907..612a487c564 100644
--- a/llama/llama.cpp/src/models/plm.cpp
+++ b/llama/llama.cpp/src/models/plm.cpp
@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
 
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
diff --git a/llama/llama.cpp/src/models/qwen35.cpp b/llama/llama.cpp/src/models/qwen35.cpp
new file mode 100644
index 00000000000..bacf7a4c2ee
--- /dev/null
+++ b/llama/llama.cpp/src/models/qwen35.cpp
@@ -0,0 +1,386 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    cb(inpL, "model.input_embed", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        // Determine layer type and build appropriate attention mechanism
+        if (hparams.is_recurrent(il)) {
+            // Linear attention layer (gated delta net)
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
+        } else {
+            // Full attention layer
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual connection
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // Dense FFN layer - without residual connection
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_ffn", il);
+
+        // Input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final norm
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // LM head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+    cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+    cb(z, "z", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen35::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int *                     sections,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
+    // Qwen3Next uses a single Q projection that outputs query + gate
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    // Apply Q normalization
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+    cb(Vcur, "Vcur", il);
+
+    // Apply K normalization
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    // Apply MRoPE
+    Qcur = ggml_rope_multi(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    Kcur = ggml_rope_multi(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    // Attention computation
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // Input projections
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    cb(beta, "beta", il);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
+    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
+    cb(alpha, "alpha", il);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
+    cb(gate, "gate", il);
+
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    // Calculate convolution kernel size
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    // Extract the convolved Q, K, V from conv_output
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
+    cb(q_conv, "q_conv", il);
+    cb(k_conv, "k_conv", il);
+    cb(v_conv, "v_conv", il);
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // if head keys and value keys are different, repeat to force tensors into matching shapes
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        // TODO: try to avoid these explicit repeats by utilizing op broadcast
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
+    if (n_seq_tokens == 1) {
+        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
+    } else {
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
+    }
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Apply gated normalization: self.norm(core_attn_out, z)
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
+
+    // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    // Output projection
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+    cb(cur, "linear_attn_out", il);
+
+    // Reshape back to original dimensions
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Qwen3.5 does not use MoE FFN
+    GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr);
+
+    cur = build_ffn(cur,
+        model.layers[il].ffn_up, NULL, NULL,
+        model.layers[il].ffn_gate, NULL, NULL,
+        model.layers[il].ffn_down, NULL, NULL,
+        NULL,
+        LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "ffn_out", il);
+
+    return cur;
+}
diff --git a/llama/llama.cpp/src/models/qwen35moe.cpp b/llama/llama.cpp/src/models/qwen35moe.cpp
new file mode 100644
index 00000000000..22d708f2062
--- /dev/null
+++ b/llama/llama.cpp/src/models/qwen35moe.cpp
@@ -0,0 +1,420 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    cb(inpL, "model.input_embed", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        // Determine layer type and build appropriate attention mechanism
+        if (hparams.is_recurrent(il)) {
+            // Linear attention layer (gated delta net)
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
+        } else {
+            // Full attention layer
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual connection
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // MOE FFN layer
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_moe", il);
+
+        // Input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final norm
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // LM head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+    cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+    cb(z, "z", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen35moe::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen35moe ::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int *                     sections,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
+    // Qwen3Next uses a single Q projection that outputs query + gate
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    // Apply Q normalization
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+    cb(Vcur, "Vcur", il);
+
+    // Apply K normalization
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    // Apply IMRoPE
+    Qcur = ggml_rope_multi(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    Kcur = ggml_rope_multi(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    // Attention computation
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // Input projections
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    cb(beta, "beta", il);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
+    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
+    cb(alpha, "alpha", il);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
+    cb(gate, "gate", il);
+
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    // Calculate convolution kernel size
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    // Extract the convolved Q, K, V from conv_output
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
+    cb(q_conv, "q_conv", il);
+    cb(k_conv, "k_conv", il);
+    cb(v_conv, "v_conv", il);
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // if head keys and value keys are different, repeat to force tensors into matching shapes
+    if (num_k_heads != num_v_heads) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        // TODO: try to avoid these explicit repeats by utilizing op broadcast
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    // Choose between build_delta_net_autoregressive and build_delta_net_chunking based on n_seq_tokens
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
+    if (n_seq_tokens == 1) {
+        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
+    } else {
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
+    }
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Apply gated normalization: self.norm(core_attn_out, z)
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
+
+    // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    // Output projection
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+    cb(cur, "linear_attn_out", il);
+
+    // Reshape back to original dimensions
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // Check if this is an MoE layer
+    GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr);
+
+    ggml_tensor * moe_out =
+        build_moe_ffn(cur,
+            model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+            model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+            nullptr,
+            n_expert, n_expert_used, LLM_FFN_SILU,
+            true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+            nullptr, model.layers[il].ffn_gate_up_exps);
+    cb(moe_out, "ffn_moe_out", il);
+
+    // Add shared experts if present - following Qwen3Next reference implementation
+    if (model.layers[il].ffn_up_shexp != nullptr) {
+        ggml_tensor * ffn_shexp =
+            build_ffn(cur,
+                model.layers[il].ffn_up_shexp, NULL, NULL,
+                model.layers[il].ffn_gate_shexp, NULL, NULL,
+                model.layers[il].ffn_down_shexp, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(ffn_shexp, "ffn_shexp", il);
+
+        // Apply shared expert gating as in the reference implementation
+        // The shared expert has its own gate that is sigmoided
+        // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
+        ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+        cb(shared_gate, "shared_expert_gate", il);
+
+        // Apply sigmoid to the gate
+        shared_gate = ggml_sigmoid(ctx0, shared_gate);
+        cb(shared_gate, "shared_expert_gate_sigmoid", il);
+
+
+        // Apply the gate to the shared expert output
+        ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+        cb(ffn_shexp, "ffn_shexp_gated", il);
+
+        cur = ggml_add(ctx0, moe_out, ffn_shexp);
+        cb(cur, "ffn_out", il);
+    } else {
+        cur = moe_out;
+    }
+
+    return cur;
+}
diff --git a/llama/llama.cpp/src/models/qwen3next.cpp b/llama/llama.cpp/src/models/qwen3next.cpp
index 775b3135d35..f2621200f23 100644
--- a/llama/llama.cpp/src/models/qwen3next.cpp
+++ b/llama/llama.cpp/src/models/qwen3next.cpp
@@ -1,10 +1,9 @@
-#include "ggml.h"
 #include "models.h"
 
-#define CHUNK_SIZE 64
+#include "llama-memory-recurrent.h"
 
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_build_delta_net_base(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -16,27 +15,18 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_pos     = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
-                    GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
-
-    ggml_build_forward_expand(gf, causal_mask);
-    ggml_build_forward_expand(gf, identity);
-    ggml_build_forward_expand(gf, diag_mask);
-
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
 
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);
 
+        ggml_build_forward_expand(gf, cur);
+
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -86,344 +76,12 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    cb(decay_mask, "decay_mask", il);
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_solve", il);
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-
-    cb(attn, "attn_solved", il);
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-
-    cb(k_cumdecay, "k_cumdecay", il);
-
-    ggml_tensor * core_attn_out = nullptr;
-    ggml_tensor * new_state = ggml_dup(ctx0, state);
-
-    cb(new_state, "new_state", il);
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        auto chunkify = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-
-        auto chunkify_g = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-
-        ggml_tensor * k_chunk = chunkify(k);
-        ggml_tensor * q_chunk = chunkify(q);
-        ggml_tensor * v_chunk = chunkify(v);
-
-        ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
-        ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
-
-        ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
-        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
-
-        ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
-        attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, diag_mask);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-
-        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
-
-        // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-        // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-        // key_gdiff = key * g_diff.unsqueeze(-1)
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-        ggml_tensor * g_cum_last =
-            ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
-                                        g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
-                                        g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
-
-        ggml_tensor * gexp_last =
-            ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cum_last_3d =
-            ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
-
-        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-        ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-        ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
-                                        ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                                                        g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
-
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
-
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
-    cb(output_tokens, "output_tokens", il);
-
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
-}
-
-ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state  = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
 }
 
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -456,39 +114,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     // Split Q projection into query and gate
     // The split should be along dimension 0 (the feature dimension)
     ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
-                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+                                            Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    cb(Qcur, "Qcur_view", il);
+
     ggml_tensor * gate =
         ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                      Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
-    cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);
 
-    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
-    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-    cb(Qcur, "Qcur_reshaped", il);
-
-    // Apply Q normalization
-    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
-    cb(Qcur, "Qcur_normed", il);
-
     ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     cb(Kcur, "Kcur", il);
 
     ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);
 
-    // Apply K normalization
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
-    cb(Kcur, "Kcur_normed", il);
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
-    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
-    cb(gate, "gate_reshaped", il);
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
 
-    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
 
-    // Apply RoPE
     Qcur = ggml_rope_ext(
             ctx0, Qcur, inp_pos, nullptr,
             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -503,7 +151,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Kcur, "Kcur", il);
     cb(Vcur, "Vcur", il);
 
-    // Attention computation
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
     cur = build_attn(inp,
@@ -511,10 +158,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cb(cur, "attn_pregate", il);
 
-    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
-    cb(gate_sigmoid, "gate_sigmoid", il);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+
+    gate = ggml_sigmoid(ctx0, gate);
+    cb(gate, "gate_sigmoid", il);
 
-    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+
+    cur = ggml_mul(ctx0, cur, gate);
     cb(cur, "attn_gated", il);
 
     cur = build_lora_mm(model.layers[il].wo, cur);
@@ -523,12 +175,90 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     return cur;
 }
 
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    if (model.layers[il].wqkv) {
+        // optimized path
+        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+        cb(z, "z", il);
+
+        return { qkv_mixed, z };
+    } else {
+        // legacy (slower) path
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+        int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+        ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                              // query size
+            head_k_dim,                              // key size
+            head_v_dim * num_v_heads / num_k_heads,  // value size
+            head_v_dim * num_v_heads / num_k_heads   // z size
+        };
+
+        ggml_tensor * query =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+        cb(query, "q", il);
+
+        ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                        split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+        cb(key, "k", il);
+
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+        cb(value, "v", il);
+
+        ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                    (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+        z = ggml_cont(ctx0, z);
+        cb(z, "z", il);
+
+        // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
+        // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(query_flat, "query_flat", il);
+
+        // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(key_flat, "key_flat", il);
+
+        // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
+        ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+        cb(value_flat, "value_flat", il);
+
+        // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+        cb(qkv_mixed, "qkv_mixed", il);
+
+        return { qkv_mixed, z };
+    }
+}
+
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         llm_graph_input_rs * inp,
         ggml_tensor *        cur,
-        ggml_tensor *        causal_mask,
-        ggml_tensor *        identity,
-        ggml_tensor *        diag_mask,
         int                  il) {
     const auto * mctx_cur = inp->mctx;
 
@@ -547,15 +277,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     // Input projections
-    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
-    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
 
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
-    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
-
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
     ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
@@ -575,86 +303,43 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                    split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
-    // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    b = ggml_cont(ctx0, b);
+
+    ggml_tensor * beta = ggml_sigmoid(ctx0, b);
+
+    // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
     ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
+
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Split mixed_qkvz into query, key, value, z
-    int64_t split_sizes_qkvz[4] = {
-        head_k_dim,                              // query size
-        head_k_dim,                              // key size
-        head_v_dim * num_v_heads / num_k_heads,  // value size
-        head_v_dim * num_v_heads / num_k_heads   // z size
-    };
-
-    ggml_tensor * query =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
-    cb(query, "q", il);
-
-    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
-                                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                                     split_sizes_qkvz[0] * sizeof(float));
-    cb(key, "k", il);
-
-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
-                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
-    cb(value, "v", il);
-
-    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
-                                   mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                                   (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
-    cb(z, "z", il);
-
-    // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
-    // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(query_flat, "query_flat", il);
-
-    // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(key_flat, "key_flat", il);
-
-    // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
-    cb(value_flat, "value_flat", il);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
 
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
-
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
 
-    // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
-    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
-    qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
-    cb(qkv_mixed, "qkv_mixed", il);
-
-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
-
-    // Calculate the total conv dimension
-    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
-
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
+
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
 
@@ -671,45 +356,54 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(state_update_target, "state_update_target", il);
 
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
-    cb(conv_states_all, "conv_states_updated", il);
 
-    // Apply SSM convolution
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
 
-    conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
-    cb(conv_output_proper, "conv_output_pre_silu", il);
-
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    ggml_tensor * conv_qkv_mix =
-        ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
-    cb(conv_qkv_mix, "conv_qkv_mix", il);
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
 
     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
-                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
-                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
 
-    // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    const float eps_norm = hparams.f_norm_rms_eps;
 
-    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
 
-    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
-    cb(state, "state_predelta", il);
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -738,48 +432,28 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(v_conv, "v_conv_predelta", il);
 
     // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    ggml_tensor * attn_out;
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
     if (n_seq_tokens == 1) {
         attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
     } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
     }
-    cb(attn_out, "attn_out", il);
-
-    // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
-    ggml_tensor * attn_out_1d      = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
-    cb(attn_out_1d, "attn_out_1d", il);
-
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    cb(attn_out_final, "attn_out_reshaped", il);
-
-    // Extract the state part (second part of the concatenated tensor)
-    // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
-
-    ggml_tensor * state_1d =
-        ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
-    cb(state_1d, "state_1d", il);
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
 
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-                              ggml_cpy(ctx0, state_1d,
-                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
-
-    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
-
-    // Reshape both attn_out_final and z to 2D tensors for normalization
-    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final =
-        ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
-    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -790,7 +464,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
     return cur;
 }
 
@@ -804,14 +479,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                 model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
                 nullptr,
                 n_expert, n_expert_used, LLM_FFN_SILU,
-                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+                nullptr, model.layers[il].ffn_gate_up_exps);
         cb(moe_out, "ffn_moe_out", il);
 
         // Add shared experts if present - following Qwen3Next reference implementation
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
                 build_ffn(cur,
-                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_up_shexp,   NULL, NULL,
                     model.layers[il].ffn_gate_shexp, NULL, NULL,
                     model.layers[il].ffn_down_shexp, NULL, NULL,
                     NULL,
@@ -824,17 +500,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
             cb(shared_gate, "shared_expert_gate", il);
 
-            // Apply sigmoid to the gate
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);
 
-            // The gate needs to be broadcast to match the dimensions of ffn_shexp
-            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
-            // We need to repeat the gate along the feature dimension
-            shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
-            cb(shared_gate, "shared_expert_gate_broadcast", il);
-
-            // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);
 
diff --git a/llama/llama.cpp/src/models/qwen3vl-moe.cpp b/llama/llama.cpp/src/models/qwen3vl-moe.cpp
index f72f80a8376..e5e1a2150c8 100644
--- a/llama/llama.cpp/src/models/qwen3vl-moe.cpp
+++ b/llama/llama.cpp/src/models/qwen3vl-moe.cpp
@@ -2,7 +2,8 @@
 
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
+
+    const int64_t n_embd      = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
 
diff --git a/llama/llama.cpp/src/models/qwen3vl.cpp b/llama/llama.cpp/src/models/qwen3vl.cpp
index 0bae52239ca..0f8315b3240 100644
--- a/llama/llama.cpp/src/models/qwen3vl.cpp
+++ b/llama/llama.cpp/src/models/qwen3vl.cpp
@@ -2,7 +2,8 @@
 
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
+
+    const int64_t n_embd      = hparams.n_embd;
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
 
diff --git a/llama/llama.cpp/src/models/rwkv6-base.cpp b/llama/llama.cpp/src/models/rwkv6-base.cpp
index 7beed2daffb..83aeab7280b 100644
--- a/llama/llama.cpp/src/models/rwkv6-base.cpp
+++ b/llama/llama.cpp/src/models/rwkv6-base.cpp
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}
diff --git a/llama/llama.cpp/src/models/rwkv7-base.cpp b/llama/llama.cpp/src/models/rwkv7-base.cpp
index cda44653849..7fcab77745c 100644
--- a/llama/llama.cpp/src/models/rwkv7-base.cpp
+++ b/llama/llama.cpp/src/models/rwkv7-base.cpp
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}
diff --git a/llama/llama.cpp/src/models/smallthinker.cpp b/llama/llama.cpp/src/models/smallthinker.cpp
index 277eec29554..4c497ca76f4 100644
--- a/llama/llama.cpp/src/models/smallthinker.cpp
+++ b/llama/llama.cpp/src/models/smallthinker.cpp
@@ -26,10 +26,16 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model,
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA  = inpL;
-        ggml_tensor * probs  = nullptr;
 
-        probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL);  // [n_expert, n_tokens]
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
+        const bool use_rope = hparams.n_no_rope_layer_step == n_layer ||
+                              il % hparams.n_no_rope_layer_step != 0;
+
+        ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL);  // [n_expert, n_tokens]
         cb(probs, "ffn_moe_logits", il);
 
         // norm
@@ -52,11 +58,11 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model,
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
-                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            if (use_rope) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                     ext_factor, attn_factor, beta_fast, beta_slow);
 
-                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                     ext_factor, attn_factor, beta_fast, beta_slow);
             }
             cb(Qcur, "Qcur", il);
diff --git a/llama/llama.cpp/src/models/step35-iswa.cpp b/llama/llama.cpp/src/models/step35-iswa.cpp
new file mode 100644
index 00000000000..f8737815a67
--- /dev/null
+++ b/llama/llama.cpp/src/models/step35-iswa.cpp
@@ -0,0 +1,168 @@
+#include "models.h"
+
+llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos     = build_inp_pos();
+    auto        * inp_attn    = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        const uint32_t n_head_l    = hparams.n_head(il);
+        const uint32_t n_head_kv_l = hparams.n_head_kv(il);
+
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // dump pre-attn RMSNorm input to pinpoint layer boundary issues
+        cb(cur, "attn_norm_in", il);
+
+        // self-attention
+        {
+            cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            // Q/K per-head RMSNorm (Step35 q_norm / k_norm)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            // RoPE (partial rotary factors per layer)
+            const bool is_swa = hparams.is_swa(il);
+            ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
+            const int64_t n_rot_l = is_swa ? hparams.n_rot : (hparams.n_rot / 2);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur_pos", il);
+            cb(Kcur, "Kcur_pos", il);
+
+            const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
+            ggml_tensor * attn_out = build_attn(inp_attn,
+                    nullptr, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+            // head-wise attention gate: sigmoid(g_proj(x)) in torch
+            if (model.layers[il].wqkv_gate) {
+                ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
+                cb(gate, "attn_gate", il);
+
+                gate = ggml_sigmoid(ctx0, gate);
+                cb(gate, "attn_gate_sigmoid", il);
+
+                // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
+                ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
+                ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate,       1,          n_head_l, n_tokens);
+                cb(gate_3d, "attn_gate_3d", il);
+
+                attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
+                cb(attn_3d, "attn_gated_3d", il);
+
+                attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+                cb(attn_out, "attn_gated", il);
+            }
+
+            // output projection
+            cur = build_lora_mm(model.layers[il].wo, attn_out);
+            cb(cur, "attn_proj", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   nullptr,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE routed experts
+            const bool  norm_w  = hparams.expert_weights_norm;
+            const float w_scale = hparams.expert_weights_scale;
+            const bool  scale_w = w_scale != 0.0f;
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU,
+                    norm_w, scale_w, w_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // shared expert MLP (always added on MoE layers in Step35)
+            ggml_tensor * sh_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   nullptr, nullptr,
+                    model.layers[il].ffn_gate_shexp, nullptr, nullptr,
+                    model.layers[il].ffn_down_shexp, nullptr, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(sh_out, "ffn_shared_out", il);
+
+            cur = ggml_add(ctx0, moe_out, sh_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp
index 13ced055f21..1475b53b659 100644
--- a/llama/llama.cpp/src/unicode.cpp
+++ b/llama/llama.cpp/src/unicode.cpp
@@ -1,21 +1,10 @@
-#if defined(_MSC_VER)
-#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
-#endif
-
-#if defined(_WIN32)
-#define WIN32_LEAN_AND_MEAN
-#include 
-#endif
-
 #include "unicode.h"
 #include "unicode-data.h"
 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -204,43 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() {
     return map;
 }
 
-static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
-#ifdef _WIN32
-    int wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, NULL, 0);
-    if (!wlen) {
-        throw std::invalid_argument("failed to convert regex");
-    }
-    wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
-    wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, wbuf, wlen);
-    if (!wlen) {
-        free(wbuf);
-        throw std::invalid_argument("failed to convert regex");
-    }
-    std::wstring ret = std::wstring(wbuf);
-    free(wbuf);
-    return ret;
-#else
-#if defined(__clang__)
-    // disable C++17 deprecation warning for std::codecvt_utf8
-#    pragma clang diagnostic push
-#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic push
-#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
-    std::wstring_convert> conv;
-
-#if defined(__clang__)
-#    pragma clang diagnostic pop
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic pop
-#endif
-
-    return conv.from_bytes(s);
-#endif
-}
-
 static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) {
     std::vector bpe_encoded_words;
     for (const auto & word : bpe_words) {
@@ -518,49 +470,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) {
-    std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
-    std::vector bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
-    std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
+template 
+static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) {
+    using BidirIt = typename std::basic_string::const_iterator;
+#ifdef _MSC_VER
+    // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830
+    constexpr auto regex_flags = std::regex_constants::ECMAScript;
+#else
+    constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs;
+#endif
+    std::basic_regex expr(regex, regex_flags);
     std::vector bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
+        std::regex_iterator it(text.begin() + start, text.begin() + start + offset, expr);
+        std::regex_iterator end;
 
         int64_t start_idx = 0;
         while (it != end) {
-            std::cmatch match = *it;
+            std::match_results match = *it;
             if (match.position() > start_idx) {
                 bpe_offsets.emplace_back(match.position() - start_idx);
             }
@@ -840,6 +769,12 @@ static std::vector unicode_regex_split_custom(const std::string & text,
     } else if (regex_expr == "\\p{AFMoE_digits}") {
         // AFMOE digit pattern - use custom implementation for proper splitting
         bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
+    } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") {
+        // tiny_aya digit grouping pattern from tokenizer.json:
+        //   {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"}
+        // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567)
+        // TODO: Revisit this regex, in case there are any subtle tokenization differences from the original regex.
+        bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
     }
 
     return bpe_offsets;
@@ -985,6 +920,11 @@ std::vector unicode_regex_split(const std::string & text, const std
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
         { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
         { "\\p{S}", unicode_cpt_flags::SYMBOL },
+        { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
+        { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
+        { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
+        { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
+        { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
     };
 
     static const std::map k_ucat_cpt = {
@@ -1067,10 +1007,10 @@ std::vector unicode_regex_split(const std::string & text, const std
                     break;
                 }
             }
+            const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
 
             if (use_collapsed) {
                 // sanity-check that the original regex does not contain any non-ASCII characters
-                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
                 for (size_t i = 0; i < cpts_regex.size(); ++i) {
                     if (cpts_regex[i] >= 128) {
                         throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
@@ -1095,22 +1035,26 @@ std::vector unicode_regex_split(const std::string & text, const std
                         continue;
                     }
 
-                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                    // Match \p{...} Unicode properties of varying lengths
+                    if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
                         regex_expr[i + 1] == 'p' &&
-                        regex_expr[i + 2] == '{' &&
-                        regex_expr[i + 4] == '}') {
-                        const std::string pat = regex_expr.substr(i, 5);
-                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
-                            if (!inside) {
-                                regex_expr_collapsed += '[';
-                            }
-                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
-                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
-                            if (!inside) {
-                                regex_expr_collapsed += ']';
+                        regex_expr[i + 2] == '{') {
+                        // Find the closing brace
+                        size_t closing_brace = regex_expr.find('}', i + 3);
+                        if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
+                            const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
+                            if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                                if (!inside) {
+                                    regex_expr_collapsed += '[';
+                                }
+                                regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                                regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                                if (!inside) {
+                                    regex_expr_collapsed += ']';
+                                }
+                                i = closing_brace;
+                                continue;
                             }
-                            i += 4;
-                            continue;
                         }
                     }
 
@@ -1122,7 +1066,7 @@ std::vector unicode_regex_split(const std::string & text, const std
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
             } else {
                 // no unicode category used, we can use std::wregex directly
-                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
+                std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end());
 
                 // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
                 std::wstring wtext(cpts.begin(), cpts.end());
diff --git a/llama/llama.cpp/tools/mtmd/clip-graph.h b/llama/llama.cpp/tools/mtmd/clip-graph.h
index 2b1915779f2..4c7f7504cfc 100644
--- a/llama/llama.cpp/tools/mtmd/clip-graph.h
+++ b/llama/llama.cpp/tools/mtmd/clip-graph.h
@@ -32,10 +32,6 @@ struct clip_graph {
     const float kq_scale;
     const clip_flash_attn_type flash_attn_type;
 
-    // for debugging
-    const bool debug_graph;
-    std::vector & debug_print_tensors;
-
     ggml_context_ptr ctx0_ptr;
     ggml_context * ctx0;
     ggml_cgraph * gf;
diff --git a/llama/llama.cpp/tools/mtmd/clip-impl.h b/llama/llama.cpp/tools/mtmd/clip-impl.h
index d75233cc0a9..a30c32ed42b 100644
--- a/llama/llama.cpp/tools/mtmd/clip-impl.h
+++ b/llama/llama.cpp/tools/mtmd/clip-impl.h
@@ -36,6 +36,8 @@
 // vision-specific
 #define KEY_VISION_PROJ_TYPE    "clip.vision.projector_type" // for models with mixed modalities
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
+#define KEY_IMAGE_MIN_PIXELS    "clip.vision.image_min_pixels"
+#define KEY_IMAGE_MAX_PIXELS    "clip.vision.image_max_pixels"
 #define KEY_PREPROC_IMAGE_SIZE  "clip.vision.preproc_image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
@@ -45,13 +47,14 @@
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
 
-#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
-#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
-#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
-#define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
-#define KEY_MINICPMV_VERSION      "clip.minicpmv_version"
-#define KEY_MINICPMV_QUERY_NUM    "clip.minicpmv_query_num"
+#define KEY_MM_PATCH_MERGE_TYPE    "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS   "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION  "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN       "clip.vision.n_wa_pattern"
+#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes"
+#define KEY_ATTN_WINDOW_SIZE       "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION       "clip.minicpmv_version"
+#define KEY_MINICPMV_QUERY_NUM     "clip.minicpmv_query_num"
 
 // audio-specific
 #define KEY_AUDIO_PROJ_TYPE     "clip.audio.projector_type" // for models with mixed modalities
@@ -138,6 +141,62 @@
 #define TN_TOK_BOI         "v.boi"
 #define TN_TOK_EOI         "v.eoi"
 
+// (conformer) lfm2
+#define TN_PRE_ENCODE_OUT  "a.pre_encode.out.%s"
+#define TN_FFN_NORM        "%s.blk.%d.ffn_norm.%s"
+#define TN_FFN_NORM_1      "%s.blk.%d.ffn_norm_1.%s"
+#define TN_FFN_UP_1        "%s.blk.%d.ffn_up_1.%s"
+#define TN_FFN_DOWN_1      "%s.blk.%d.ffn_down_1.%s"
+#define TN_POS_BIAS_U      "%s.blk.%d.pos_bias_u"
+#define TN_POS_BIAS_V      "%s.blk.%d.pos_bias_v"
+#define TN_NORM_CONV       "%s.blk.%d.norm_conv.%s"
+#define TN_LINEAR_POS      "%s.blk.%d.linear_pos.%s"
+#define TN_CONV_DW         "%s.blk.%d.conv_dw.%s"
+#define TN_CONV_NORM       "%s.blk.%d.conv_norm.%s"
+#define TN_CONV_PW1        "%s.blk.%d.conv_pw1.%s"
+#define TN_CONV_PW2        "%s.blk.%d.conv_pw2.%s"
+
+// mobilenetv5 (gemma3n) definitions
+#define TN_MNV5_STEM_CONV        "v.conv_stem.conv.weight"
+#define TN_MNV5_STEM_BIAS        "v.conv_stem.conv.bias"
+#define TN_MNV5_STEM_BN          "v.conv_stem.bn.weight"
+
+// Stage 0 Block (Edge Residual)
+#define TN_MNV5_BLK_S0_EXP_W     "v.blk.%d.%d.conv_exp.weight"
+#define TN_MNV5_BLK_S0_BN1_W     "v.blk.%d.%d.bn1.weight"
+#define TN_MNV5_BLK_S0_PWL_W     "v.blk.%d.%d.conv_pwl.weight"
+#define TN_MNV5_BLK_S0_BN2_W     "v.blk.%d.%d.bn2.weight"
+
+// Stage 1+ Block (Universal Inverted Residual)
+#define TN_MNV5_BLK_DW_START_W   "v.blk.%d.%d.dw_start.conv.weight"
+#define TN_MNV5_BLK_DW_START_BN  "v.blk.%d.%d.dw_start.bn.weight"
+#define TN_MNV5_BLK_DW_MID_W     "v.blk.%d.%d.dw_mid.conv.weight"
+#define TN_MNV5_BLK_DW_MID_BN    "v.blk.%d.%d.dw_mid.bn.weight"
+#define TN_MNV5_BLK_PW_EXP_W     "v.blk.%d.%d.pw_exp.conv.weight"
+#define TN_MNV5_BLK_PW_EXP_BN    "v.blk.%d.%d.pw_exp.bn.weight"
+#define TN_MNV5_BLK_PW_PROJ_W    "v.blk.%d.%d.pw_proj.conv.weight"
+#define TN_MNV5_BLK_PW_PROJ_BN   "v.blk.%d.%d.pw_proj.bn.weight"
+#define TN_MNV5_BLK_LAYER_SCALE  "v.blk.%d.%d.layer_scale.gamma"
+
+// Attention Components
+#define TN_MNV5_ATTN_Q_W         "v.blk.%d.%d.attn.query.proj.weight"
+#define TN_MNV5_ATTN_K_W         "v.blk.%d.%d.attn.key.proj.weight"
+#define TN_MNV5_ATTN_V_W         "v.blk.%d.%d.attn.value.proj.weight"
+#define TN_MNV5_ATTN_O_W         "v.blk.%d.%d.attn.output.proj.weight"
+#define TN_MNV5_ATTN_K_DW        "v.blk.%d.%d.attn.key.down_conv.weight"
+#define TN_MNV5_ATTN_K_NORM      "v.blk.%d.%d.attn.key.norm.weight"
+#define TN_MNV5_ATTN_V_DW        "v.blk.%d.%d.attn.value.down_conv.weight"
+#define TN_MNV5_ATTN_V_NORM      "v.blk.%d.%d.attn.value.norm.weight"
+#define TN_MNV5_ATTN_NORM        "v.blk.%d.%d.norm.weight" // Block norm used in attn blocks
+
+// MSFA
+#define TN_MNV5_MSFA_FFN_EXP_W   "v.msfa.ffn.pw_exp.conv.weight"
+#define TN_MNV5_MSFA_FFN_EXP_BN  "v.msfa.ffn.pw_exp.bn.weight"
+#define TN_MNV5_MSFA_FFN_PROJ_W  "v.msfa.ffn.pw_proj.conv.weight"
+#define TN_MNV5_MSFA_FFN_PROJ_BN "v.msfa.ffn.pw_proj.bn.weight"
+#define TN_MNV5_MSFA_NORM        "v.msfa.norm.weight"
+
+
 // align x to upper multiple of n
 #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
@@ -155,6 +214,8 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN2VL,
     PROJECTOR_TYPE_QWEN3VL,
     PROJECTOR_TYPE_GEMMA3,
+    PROJECTOR_TYPE_GEMMA3NV,
+    PROJECTOR_TYPE_GEMMA3NA,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
@@ -165,12 +226,18 @@ enum projector_type {
     PROJECTOR_TYPE_GLMA,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
+    PROJECTOR_TYPE_MUSIC_FLAMINGO,
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
+    PROJECTOR_TYPE_PADDLEOCR,
     PROJECTOR_TYPE_LIGHTONOCR,
     PROJECTOR_TYPE_COGVLM,
     PROJECTOR_TYPE_JANUS_PRO,
+    PROJECTOR_TYPE_LFM2A,
     PROJECTOR_TYPE_GLM4V,
+    PROJECTOR_TYPE_YOUTUVL,
+    PROJECTOR_TYPE_KIMIK25,
+    PROJECTOR_TYPE_NEMOTRON_V2_VL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -184,6 +251,8 @@ static std::map PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
     { PROJECTOR_TYPE_QWEN3VL,   "qwen3vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
+    { PROJECTOR_TYPE_GEMMA3NV,  "gemma3nv"},
+    { PROJECTOR_TYPE_GEMMA3NA,  "gemma3na"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
     { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
@@ -193,12 +262,18 @@ static std::map PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GLMA,      "glma"},
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
+    { PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
     { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
+    { PROJECTOR_TYPE_PADDLEOCR, "paddleocr"},
     { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
     { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
     { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
+    { PROJECTOR_TYPE_LFM2A,     "lfm2a"},
     { PROJECTOR_TYPE_GLM4V,     "glm4v"},
+    { PROJECTOR_TYPE_YOUTUVL,   "youtuvl"},
+    { PROJECTOR_TYPE_KIMIK25,   "kimik25"},
+    { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/llama/llama.cpp/tools/mtmd/clip-model.h b/llama/llama.cpp/tools/mtmd/clip-model.h
index f5c41ff1389..e0eb9b32c8f 100644
--- a/llama/llama.cpp/tools/mtmd/clip-model.h
+++ b/llama/llama.cpp/tools/mtmd/clip-model.h
@@ -4,6 +4,7 @@
 #include "clip.h"
 #include "clip-impl.h"
 
+#include 
 #include 
 #include 
 #include 
@@ -14,6 +15,7 @@ enum ffn_op_type {
     FFN_GELU_ERF,
     FFN_SILU,
     FFN_GELU_QUICK,
+    FFN_RELU_SQR,
 };
 
 enum norm_type {
@@ -60,6 +62,7 @@ struct clip_hparams {
     std::unordered_set vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
 
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
@@ -142,11 +145,74 @@ struct clip_layer {
     ggml_tensor * deepstack_fc2_w = nullptr;
     ggml_tensor * deepstack_fc2_b = nullptr;
 
+    // lfm2
+    ggml_tensor * ff_norm_w     = nullptr;
+    ggml_tensor * ff_norm_b     = nullptr;
+    ggml_tensor * ff_norm_1_w   = nullptr;
+    ggml_tensor * ff_norm_1_b   = nullptr;
+    ggml_tensor * ff_up_1_w     = nullptr;
+    ggml_tensor * ff_up_1_b     = nullptr;
+    ggml_tensor * ff_down_1_w   = nullptr;
+    ggml_tensor * ff_down_1_b   = nullptr;
+    ggml_tensor * pos_bias_u    = nullptr;
+    ggml_tensor * pos_bias_v    = nullptr;
+    ggml_tensor * norm_conv_w   = nullptr;
+    ggml_tensor * norm_conv_b   = nullptr;
+    ggml_tensor * linear_pos_w  = nullptr;
+
+    ggml_tensor * conv_norm_w   = nullptr;
+    ggml_tensor * conv_norm_b   = nullptr;
+    ggml_tensor * conv_dw_w     = nullptr;
+    ggml_tensor * conv_dw_b     = nullptr;
+    ggml_tensor * conv_pw1_w    = nullptr;
+    ggml_tensor * conv_pw1_b    = nullptr;
+    ggml_tensor * conv_pw2_w    = nullptr;
+    ggml_tensor * conv_pw2_b    = nullptr;
+
     bool has_deepstack() const {
         return deepstack_fc1_w != nullptr;
     }
 };
 
+// Expanded MobileNetV5 block structure for Gemma3n vision encoder
+struct mobilenetv5_block {
+    // Stage 0 (Edge Residual)
+    ggml_tensor * s0_conv_exp_w = nullptr;
+    ggml_tensor * s0_bn1_w      = nullptr;
+    ggml_tensor * s0_conv_pwl_w = nullptr;
+    ggml_tensor * s0_bn2_w      = nullptr;
+
+    // Stage 1+ (Universal Inverted Residual)
+    ggml_tensor * dw_start_w    = nullptr;
+    ggml_tensor * dw_start_bn_w = nullptr;
+
+    ggml_tensor * pw_exp_w      = nullptr;
+    ggml_tensor * pw_exp_bn_w   = nullptr;
+
+    ggml_tensor * dw_mid_w      = nullptr;
+    ggml_tensor * dw_mid_bn_w   = nullptr;
+
+    ggml_tensor * pw_proj_w     = nullptr;
+    ggml_tensor * pw_proj_bn_w  = nullptr;
+
+    ggml_tensor * layer_scale_w = nullptr;
+
+    // Attention (MQA) components
+    ggml_tensor * attn_q_w = nullptr;
+    ggml_tensor * attn_k_w = nullptr;
+    ggml_tensor * attn_v_w = nullptr;
+    ggml_tensor * attn_o_w = nullptr;
+
+    // Optional downsampling/norm in attention
+    ggml_tensor * attn_k_dw_w   = nullptr;
+    ggml_tensor * attn_k_norm_w = nullptr;
+    ggml_tensor * attn_v_dw_w   = nullptr;
+    ggml_tensor * attn_v_norm_w = nullptr;
+
+    // Block norm (often present in attention blocks)
+    ggml_tensor * attn_norm_w   = nullptr;
+};
+
 struct clip_model {
     clip_modality modality = CLIP_MODALITY_VISION;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -263,6 +329,23 @@ struct clip_model {
     ggml_tensor * mm_input_proj_w = nullptr;
     ggml_tensor * mm_soft_emb_norm_w = nullptr;
 
+    // mobilenetv5 for gemma3n
+    std::vector mobilenet_blocks;
+    std::vector mobilenet_stage_ends;
+    ggml_tensor * mobilenet_stem_conv_w = nullptr;
+    ggml_tensor * mobilenet_stem_conv_b = nullptr;
+    ggml_tensor * mobilenet_stem_norm_w = nullptr;
+    ggml_tensor * mm_post_proj_norm_w = nullptr;
+
+    // Multi-Scale Fusion Adapter (MSFA) components
+    ggml_tensor * msfa_concat_conv_w = nullptr;
+    ggml_tensor * msfa_concat_norm_w = nullptr;
+    ggml_tensor * msfa_ffn_expand_w = nullptr;
+    ggml_tensor * msfa_ffn_project_w = nullptr;
+    ggml_tensor * msfa_ffn_expand_bn = nullptr;
+    ggml_tensor * msfa_ffn_project_bn = nullptr;
+
+
     // pixtral, glm4v
     ggml_tensor * token_embd_img_break = nullptr;
     ggml_tensor * mm_patch_merger_w = nullptr;
@@ -286,9 +369,16 @@ struct clip_model {
     ggml_tensor * mm_boi = nullptr;
     ggml_tensor * mm_eoi = nullptr;
 
+    // lfm2 audio
+    std::array pre_encode_conv_X_w = {nullptr};
+    std::array pre_encode_conv_X_b = {nullptr};
+    ggml_tensor * pre_encode_out_w = nullptr;
+    ggml_tensor * pre_encode_out_b = nullptr;
+
     bool audio_has_avgpool() const {
         return proj_type == PROJECTOR_TYPE_QWEN2A
-            || proj_type == PROJECTOR_TYPE_VOXTRAL;
+            || proj_type == PROJECTOR_TYPE_VOXTRAL
+            || proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO;
     }
 
     bool audio_has_stack_frames() const {
diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp b/llama/llama.cpp/tools/mtmd/clip.cpp
index d3a37842df2..db83c9ecfdb 100644
--- a/llama/llama.cpp/tools/mtmd/clip.cpp
+++ b/llama/llama.cpp/tools/mtmd/clip.cpp
@@ -10,6 +10,7 @@
 #include "ggml-backend.h"
 #include "gguf.h"
 
+#include 
 #include 
 #include 
 #include 
@@ -165,18 +166,14 @@ struct clip_ctx {
     ggml_backend_t backend_cpu = nullptr;
     ggml_backend_buffer_ptr buf;
 
+
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
     clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO;
     bool is_allocated = false;
 
-    // for debugging
-    bool debug_graph = false;
-    std::vector debug_print_tensors;
-
     clip_ctx(clip_context_params & ctx_params) {
         flash_attn_type = ctx_params.flash_attn_type;
-        debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
             throw std::runtime_error("failed to initialize CPU backend");
@@ -217,6 +214,10 @@ struct clip_ctx {
         sched.reset(
             ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
         );
+
+        if (ctx_params.cb_eval != nullptr) {
+            ggml_backend_sched_set_eval_callback(sched.get(), ctx_params.cb_eval, ctx_params.cb_eval_user_data);
+        }
     }
 
     ~clip_ctx() {
@@ -252,9 +253,7 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
         n_mmproj_embd(clip_n_mmproj_embd(ctx)),
         eps(hparams.eps),
         kq_scale(1.0f / sqrtf((float)d_head)),
-        flash_attn_type(ctx->flash_attn_type),
-        debug_graph(ctx->debug_graph),
-        debug_print_tensors(ctx->debug_print_tensors) {
+        flash_attn_type(ctx->flash_attn_type) {
     struct ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
         /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
@@ -265,14 +264,11 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
     gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
 }
 
-void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const {
-    if (debug_graph) {
-        ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
-        std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
-        ggml_set_name(cur, cur_name.c_str());
-        ggml_set_output(cur);
-        ggml_build_forward_expand(gf, cur);
-        debug_print_tensors.push_back(cur);
+void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
+    if (il >= 0) {
+        ggml_format_name(cur, "%s-%d", name, il);
+    } else {
+        ggml_set_name(cur, name);
     }
 }
 
@@ -359,9 +355,17 @@ ggml_tensor * clip_graph::build_vit(
                     /* nb2    */ cur->nb[1],
                     /* offset */ ggml_row_size(cur->type, 2 * n_embd));
 
-                // TODO: q/k norm requires row size == n_embd, while here it's d_head
-                // we can add support in the future if needed
-                GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);
+                if (layer.q_norm) {
+                    GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
+                    Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
+                    cb(Qcur, "Qcur_norm", il);
+                }
+
+                if (layer.k_norm) {
+                    GGML_ASSERT(layer.k_norm->ne[0] == Kcur->ne[0]);
+                    Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
+                    cb(Kcur, "Kcur_norm", il);
+                }
 
             } else {
                 // separate q, k, v
@@ -576,6 +580,12 @@ ggml_tensor * clip_graph::build_ffn(
                 cur = ggml_gelu_quick(ctx0, cur);
                 cb(cur, "ffn_gelu_quick", il);
             } break;
+        case FFN_RELU_SQR:
+            {
+                cur = ggml_relu(ctx0, cur);
+                cur = ggml_sqr(ctx0, cur);
+                cb(cur, "ffn_relu_sqr", il);
+            } break;
     }
 
     if (down) {
@@ -631,9 +641,6 @@ ggml_tensor * clip_graph::build_attn(
         ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
         v = ggml_cont(ctx0, v);
 
-        const auto n_tokens = q->ne[1];
-        const auto n_head   = q->ne[2];
-
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
         // F32 may not needed for vision encoders?
         // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -642,7 +649,7 @@ ggml_tensor * clip_graph::build_attn(
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
     }
 
     cb(cur, "kqv_out", il);
@@ -690,8 +697,8 @@ ggml_tensor * clip_graph::build_rope_2d(
     {
         first = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
+            cur->nb[1],
+            cur->nb[2],
             0);
         first = ggml_rope_ext(
             ctx0,
@@ -709,8 +716,8 @@ ggml_tensor * clip_graph::build_rope_2d(
     {
         second = ggml_view_3d(ctx0, cur,
             n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
+            cur->nb[1],
+            cur->nb[2],
             n_dim/2 * ggml_element_size(cur));
         second = ggml_rope_ext(
             ctx0,
@@ -801,6 +808,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_GEMMA3NV:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
@@ -823,6 +834,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_NEMOTRON_V2_VL:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
                 builder = std::make_unique(ctx, img);
@@ -831,6 +846,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_GLMA:
+        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
             {
                 builder = std::make_unique(ctx, img);
             } break;
@@ -838,6 +854,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
+        case PROJECTOR_TYPE_KIMIK25:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
         case PROJECTOR_TYPE_COGVLM:
             {
                 builder = std::make_unique(ctx, img);
@@ -850,10 +874,18 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_LFM2A:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
         case PROJECTOR_TYPE_GLM4V:
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_YOUTUVL:
+            {
+                builder = std::make_unique(ctx, img);
+            } break;
         default:
             GGML_ABORT("missing cgraph builder");
     }
@@ -1014,6 +1046,8 @@ struct clip_model_loader {
                         hparams.minicpmv_query_num = 64;
                     } else if (hparams.minicpmv_version == 6) {
                         hparams.minicpmv_query_num = 64;
+                    } else if (hparams.minicpmv_version == 100045) {
+                        hparams.minicpmv_query_num = 64;
                     } else {
                         hparams.minicpmv_query_num = 96;
                     }
@@ -1112,6 +1146,7 @@ struct clip_model_loader {
                         }
                     } break;
                 case PROJECTOR_TYPE_INTERNVL:
+                case PROJECTOR_TYPE_NEMOTRON_V2_VL:
                     {
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
                     } break;
@@ -1123,9 +1158,8 @@ struct clip_model_loader {
                 case PROJECTOR_TYPE_LFM2:
                     {
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
-                        // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json
-                        // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64
-                        hparams.set_limit_image_tokens(64, 1024);
+                        // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json
+                        hparams.set_limit_image_tokens(64, 256);
                     } break;
                 case PROJECTOR_TYPE_PIXTRAL:
                 case PROJECTOR_TYPE_LIGHTONOCR:
@@ -1146,6 +1180,22 @@ struct clip_model_loader {
                         hparams.set_limit_image_tokens(8, 1024);
                         hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                     } break;
+                case PROJECTOR_TYPE_KIMIK25:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+
+                        int min_pixels = 0, max_pixels = 0;
+                        get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false);
+                        get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false);
+                        if (min_pixels > 0 && max_pixels > 0) {
+                            hparams.image_min_pixels = min_pixels;
+                            hparams.image_max_pixels = max_pixels;
+                            hparams.warmup_image_size = static_cast<int>(std::sqrt(max_pixels));
+                        } else {
+                            hparams.set_limit_image_tokens(2, 4096);
+                        }
+                    } break;
                 case PROJECTOR_TYPE_GEMMA3:
                     {
                         // default value (used by all model sizes in gemma 3 family)
@@ -1154,6 +1204,14 @@ struct clip_model_loader {
                         // test model (tinygemma3) has a different value, we optionally read it
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
                     } break;
+
+                case PROJECTOR_TYPE_GEMMA3NV:
+                    {
+                        // Gemma3n uses MobileNetV5 which produces 256 tokens (16x16)
+                        // Similar configuration to Gemma3
+                        hparams.n_merge = 1;  // MobileNetV5 handles resizing internally
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                    } break;
                 case PROJECTOR_TYPE_QWEN2VL:
                 case PROJECTOR_TYPE_QWEN25VL:
                 case PROJECTOR_TYPE_QWEN3VL:
@@ -1171,6 +1229,20 @@ struct clip_model_loader {
                             LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
                         }
                     } break;
+                case PROJECTOR_TYPE_YOUTUVL:
+                    {
+                        hparams.n_merge = 2;
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                        get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
+                        std::vector<int> wa_layer_indexes_vec;
+                        get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true);
+                        for (auto & layer : wa_layer_indexes_vec) {
+                            hparams.wa_layer_indexes.insert(layer);
+                        }
+                        // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens
+                        hparams.set_limit_image_tokens(1, 62500);
+                        hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup
+                    } break;
                 case PROJECTOR_TYPE_GLM4V:
                     {
                         hparams.rope_theta = 10000.0f;
@@ -1189,6 +1261,7 @@ struct clip_model_loader {
                 case PROJECTOR_TYPE_QWEN2A:
                 case PROJECTOR_TYPE_GLMA:
                 case PROJECTOR_TYPE_VOXTRAL:
+                case PROJECTOR_TYPE_MUSIC_FLAMINGO:
                     {
                         bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
                                              model.proj_type == PROJECTOR_TYPE_VOXTRAL ||
@@ -1204,6 +1277,23 @@ struct clip_model_loader {
                         hparams.audio_window_len   = 400;
                         hparams.audio_hop_len      = 160;
                     } break;
+                case PROJECTOR_TYPE_PADDLEOCR:
+                    {
+                        hparams.n_merge = 2;
+                        get_u32(KEY_IMAGE_MIN_PIXELS, hparams.image_min_pixels);
+                        get_u32(KEY_IMAGE_MAX_PIXELS, hparams.image_max_pixels);
+
+                        hparams.set_warmup_n_tokens(28*28); // avoid OOM on warmup
+                    } break;
+                case PROJECTOR_TYPE_LFM2A:
+                    {
+                        // audio preprocessing params
+                        hparams.audio_chunk_len        = 1; // in seconds
+                        hparams.audio_sample_rate      = 16000;
+                        hparams.audio_n_fft            = 512;
+                        hparams.audio_window_len       = 400;
+                        hparams.audio_hop_len          = 160;
+                    } break;
                 default:
                     break;
             }
@@ -1229,7 +1319,14 @@ struct clip_model_loader {
                 LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
                 LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
                 LOG_INF("%s: n_merge:            %d\n", __func__, hparams.n_merge);
-                LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
+                LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+                if (!hparams.wa_layer_indexes.empty()) {
+                    LOG_INF("%s: wa_layer_indexes:  ", __func__);
+                    for (auto & layer : hparams.wa_layer_indexes) {
+                        LOG_INF("%d ", layer);
+                    }
+                    LOG_INF("\n");
+                }
                 if (hparams.image_min_pixels > 0) {
                     LOG_INF("%s: image_min_pixels:   %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
                 }
@@ -1311,6 +1408,10 @@ struct clip_model_loader {
 
         model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
 
+        if (model.proj_type == PROJECTOR_TYPE_GEMMA3NV) {
+            hparams.n_layer = 0; // gemma3n does not use normal layer structure
+        }
+
         // layers
         model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
@@ -1385,6 +1486,7 @@ struct clip_model_loader {
             }
         }
 
+
         switch (model.proj_type) {
             case PROJECTOR_TYPE_MLP:
             case PROJECTOR_TYPE_MLP_NORM:
@@ -1479,8 +1581,8 @@ struct clip_model_loader {
                     model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
                     model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
                     model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
-                    model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
-                    model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+                    model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI));
+                    model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI));
                 } break;
             case PROJECTOR_TYPE_QWEN2VL:
             case PROJECTOR_TYPE_QWEN25VL:
@@ -1497,6 +1599,14 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
                 } break;
+            case PROJECTOR_TYPE_YOUTUVL:
+                {
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);        // merger.ln_q (RMS norm)
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));  // merger.mlp.0
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));  // merger.mlp.2
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_GLM4V:
                 {
                     model.projection     = get_tensor(TN_MM_PROJECTOR);
@@ -1516,12 +1626,115 @@ struct clip_model_loader {
                     model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
                     model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
                 } break;
+            case PROJECTOR_TYPE_GEMMA3NV:
+                {
+                    model.mobilenet_stem_conv_w = get_tensor(TN_MNV5_STEM_CONV, false);
+                    model.mobilenet_stem_conv_b = get_tensor(TN_MNV5_STEM_BIAS, false);
+                    model.mobilenet_stem_norm_w = get_tensor(TN_MNV5_STEM_BN, false);
+
+                    model.msfa_ffn_expand_w  = get_tensor(TN_MNV5_MSFA_FFN_EXP_W, false);
+                    model.msfa_ffn_expand_bn = get_tensor(TN_MNV5_MSFA_FFN_EXP_BN, false); // Consume BN if present but likely folded
+                    model.msfa_ffn_project_w = get_tensor(TN_MNV5_MSFA_FFN_PROJ_W, false);
+                    model.msfa_ffn_project_bn = get_tensor(TN_MNV5_MSFA_FFN_PROJ_BN, false);
+
+                    model.msfa_concat_norm_w = get_tensor(TN_MNV5_MSFA_NORM, false);
+
+                    // Dynamically load blocks stage by stage
+                    for (int stage = 0; stage < 4; ++stage) {
+                        int blocks_found_in_stage = 0;
+
+                        for (int blk_idx = 0; ; ++blk_idx) {
+                            bool found_block = false;
+                            mobilenetv5_block block;
+
+                            // 1. Check for Edge Residual (S0)
+                            block.s0_conv_exp_w = get_tensor(string_format(TN_MNV5_BLK_S0_EXP_W, stage, blk_idx), false);
+                            if (block.s0_conv_exp_w) {
+                                found_block = true;
+                                block.s0_bn1_w      = get_tensor(string_format(TN_MNV5_BLK_S0_BN1_W, stage, blk_idx), false);
+                                block.s0_conv_pwl_w = get_tensor(string_format(TN_MNV5_BLK_S0_PWL_W, stage, blk_idx), false);
+                                block.s0_bn2_w      = get_tensor(string_format(TN_MNV5_BLK_S0_BN2_W, stage, blk_idx), false);
+                            }
+                            // 2. Check for UIR (Universal Inverted Residual)
+                            else {
+                                // Check for dw_start OR pw_exp (some UIR blocks skip dw_start)
+                                block.dw_start_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_W, stage, blk_idx), false);
+                                block.pw_exp_w   = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_W, stage, blk_idx), false);
+
+                                if (block.dw_start_w || block.pw_exp_w) {
+                                    found_block = true;
+                                    if (block.dw_start_w) {
+                                        block.dw_start_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_BN, stage, blk_idx), false);
+                                    }
+                                    if (block.pw_exp_w) {
+                                        block.pw_exp_bn_w   = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_BN, stage, blk_idx), false);
+                                    }
+                                    block.dw_mid_w      = get_tensor(string_format(TN_MNV5_BLK_DW_MID_W, stage, blk_idx), false);
+                                    if (block.dw_mid_w) {
+                                        block.dw_mid_bn_w   = get_tensor(string_format(TN_MNV5_BLK_DW_MID_BN, stage, blk_idx), false);
+                                    }
+                                    block.pw_proj_w     = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_W, stage, blk_idx), false);
+                                    if (block.pw_proj_w) {
+                                        block.pw_proj_bn_w  = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_BN, stage, blk_idx), false);
+                                    }
+                                    block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false);
+                                }
+                            }
+
+                            // 3. Check for Attention (MQA)
+                            // Even if UIR/Edge check failed, this might be a pure attention block
+                            ggml_tensor* attn_q_check = get_tensor(string_format(TN_MNV5_ATTN_Q_W, stage, blk_idx), false);
+                            if (attn_q_check) {
+                                found_block = true;
+                                block.attn_q_w = attn_q_check;
+                                block.attn_k_w = get_tensor(string_format(TN_MNV5_ATTN_K_W, stage, blk_idx), false);
+                                block.attn_v_w = get_tensor(string_format(TN_MNV5_ATTN_V_W, stage, blk_idx), false);
+                                block.attn_o_w = get_tensor(string_format(TN_MNV5_ATTN_O_W, stage, blk_idx), false);
+                                block.attn_k_dw_w   = get_tensor(string_format(TN_MNV5_ATTN_K_DW, stage, blk_idx), false);
+                                block.attn_k_norm_w = get_tensor(string_format(TN_MNV5_ATTN_K_NORM, stage, blk_idx), false);
+                                block.attn_v_dw_w   = get_tensor(string_format(TN_MNV5_ATTN_V_DW, stage, blk_idx), false);
+                                block.attn_v_norm_w = get_tensor(string_format(TN_MNV5_ATTN_V_NORM, stage, blk_idx), false);
+                                block.attn_norm_w   = get_tensor(string_format(TN_MNV5_ATTN_NORM, stage, blk_idx), false);
+                                // Note: Attention blocks also have layer_scale, load it if not already loaded by UIR check
+                                if (!block.layer_scale_w) {
+                                    block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false);
+                                }
+                            }
+
+                            if (found_block) {
+                                model.mobilenet_blocks.push_back(block);
+                                blocks_found_in_stage++;
+                            } else {
+                                // End of blocks for this stage
+                                break;
+                            }
+                        }
+
+                        // Track where this stage ends in the flat vector
+                        if (blocks_found_in_stage > 0) {
+                            model.mobilenet_stage_ends.push_back(model.mobilenet_blocks.size() - 1);
+                            LOG_INF("%s: Stage %d ended at global block index %zu\n", __func__, stage, model.mobilenet_blocks.size() - 1);
+                        }
+                    }
+                    model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
+                    model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
+                } break;
             case PROJECTOR_TYPE_IDEFICS3:
                 {
                     model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
             case PROJECTOR_TYPE_LFM2:
+                {
+                    model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B, false);
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_KIMIVL:
+            case PROJECTOR_TYPE_PADDLEOCR:
+            case PROJECTOR_TYPE_KIMIK25:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
                     model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -1580,6 +1793,17 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
                     model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
                 } break;
+            case PROJECTOR_TYPE_MUSIC_FLAMINGO:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
+                } break;
             case PROJECTOR_TYPE_INTERNVL:
                 {
                     model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -1589,6 +1813,12 @@ struct clip_model_loader {
                     model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
                     model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
                 } break;
+            case PROJECTOR_TYPE_NEMOTRON_V2_VL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                } break;
             case PROJECTOR_TYPE_GLMA:
                 {
                     model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
@@ -1601,8 +1831,8 @@ struct clip_model_loader {
                     model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias"));
                     model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
                     model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias"));
-                    model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight"));
-                    model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight"));
+                    model.mm_boi = get_tensor(string_format(TN_TOK_BOI));
+                    model.mm_eoi = get_tensor(string_format(TN_TOK_EOI));
                 } break;
             case PROJECTOR_TYPE_LLAMA4:
                 {
@@ -1628,6 +1858,52 @@ struct clip_model_loader {
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                     model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
                 } break;
+            case PROJECTOR_TYPE_LFM2A:
+                {
+                    for (int i : {0, 2, 3, 5, 6}) {
+                        model.pre_encode_conv_X_w[i] = get_tensor(string_format(TN_CONV1D, i, "weight"));
+                        model.pre_encode_conv_X_b[i] = get_tensor(string_format(TN_CONV1D, i, "bias"));
+                    }
+                    model.pre_encode_out_w    = get_tensor(string_format(TN_PRE_ENCODE_OUT, "weight"));
+                    model.pre_encode_out_b    = get_tensor(string_format(TN_PRE_ENCODE_OUT, "bias"));
+
+                    model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias"));
+                    model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight"));
+                    model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias"));
+
+                    for (int il = 0; il < hparams.n_layer; ++il) {
+                        auto & layer = model.layers[il];
+
+                        layer.ff_norm_w   = get_tensor(string_format(TN_FFN_NORM,   prefix, il, "weight"));
+                        layer.ff_norm_b   = get_tensor(string_format(TN_FFN_NORM,   prefix, il, "bias"));
+                        layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight"));
+                        layer.ff_norm_1_b = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "bias"));
+                        layer.ff_up_1_w   = get_tensor(string_format(TN_FFN_UP_1,   prefix, il, "weight"));
+                        layer.ff_up_1_b   = get_tensor(string_format(TN_FFN_UP_1,   prefix, il, "bias"));
+                        layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight"));
+                        layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias"));
+
+                        layer.pos_bias_u = get_tensor(string_format(TN_POS_BIAS_U, prefix, il));
+                        layer.pos_bias_v = get_tensor(string_format(TN_POS_BIAS_V, prefix, il));
+
+                        layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight"));
+                        layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias"));
+
+                        layer.linear_pos_w = get_tensor(string_format(TN_LINEAR_POS, prefix, il, "weight"));
+
+                        layer.conv_norm_w  = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight"));
+                        layer.conv_norm_b  = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias"));
+                        layer.conv_dw_w    = get_tensor(string_format(TN_CONV_DW,   prefix, il, "weight"));
+                        layer.conv_dw_b    = get_tensor(string_format(TN_CONV_DW,   prefix, il, "bias"));
+                        layer.conv_pw1_w   = get_tensor(string_format(TN_CONV_PW1,  prefix, il, "weight"));
+                        layer.conv_pw1_b   = get_tensor(string_format(TN_CONV_PW1,  prefix, il, "bias"));
+                        layer.conv_pw2_w   = get_tensor(string_format(TN_CONV_PW2,  prefix, il, "weight"));
+                        layer.conv_pw2_b   = get_tensor(string_format(TN_CONV_PW2,  prefix, il, "bias"));
+                    }
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -1932,6 +2208,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
 
     try {
         clip_model_loader loader(fname);
+        bool skip_audio = false;
 
         if (loader.has_vision) {
             ctx_vision = new clip_ctx(ctx_params);
@@ -1941,10 +2218,14 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
                 loader.warmup(*ctx_vision);
             }
 
+            // TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors
+            // we can remove this check when we implement audio support for Gemma 3N
+            skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV;
+
             // clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f);
         }
 
-        if (loader.has_audio) {
+        if (loader.has_audio && !skip_audio) {
             ctx_audio = new clip_ctx(ctx_params);
             loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
             loader.load_tensors(*ctx_audio);
@@ -2067,7 +2348,7 @@ struct img_tool {
             std::array pad_color = {0, 0, 0}) {
         dst.nx = target_resolution.width;
         dst.ny = target_resolution.height;
-        dst.buf.resize(3 * dst.nx * dst.ny);
+        dst.buf.resize(3 * static_cast<size_t>(dst.nx) * static_cast<size_t>(dst.ny));
 
         if (dst.nx == src.nx && dst.ny == src.ny) {
             // no resize needed, simple copy
@@ -2120,7 +2401,7 @@ struct img_tool {
     static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
         dst.nx = w;
         dst.ny = h;
-        dst.buf.resize(3 * w * h);
+        dst.buf.resize(3 * static_cast<size_t>(w) * static_cast<size_t>(h));
 
         for (int i = 0; i < h; ++i) {
             for (int j = 0; j < w; ++j) {
@@ -2217,7 +2498,7 @@ struct img_tool {
     static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
         dst.nx = target_width;
         dst.ny = target_height;
-        dst.buf.resize(3 * target_width * target_height);
+        dst.buf.resize(3 * static_cast<size_t>(target_width) * static_cast<size_t>(target_height));
 
         float x_ratio = static_cast<float>(src.nx - 1) / target_width;
         float y_ratio = static_cast<float>(src.ny - 1) / target_height;
@@ -2256,7 +2537,7 @@ struct img_tool {
 
         dst.nx = target_width;
         dst.ny = target_height;
-        dst.buf.resize(3 * target_width * target_height);
+        dst.buf.resize(3 * static_cast<size_t>(target_width) * static_cast<size_t>(target_height));
 
         float Cc;
         float C[5] = {};
@@ -2625,6 +2906,119 @@ struct llava_uhd {
     }
 };
 
+// ref: https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py
+// some of the logic is similar to llava_uhd, but with different hyperparameters and some logic is unique (e.g. grid layout)
+struct lfm2_vl_image_processor {
+    // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json
+    static constexpr int   min_tiles            = 2;
+    static constexpr int   max_tiles            = 10;
+    static constexpr float max_pixels_tolerance = 2.0f;
+    static constexpr int   tile_size            = 512;
+
+    static llava_uhd::slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
+        llava_uhd::slice_instructions inst;
+        const auto & params  = ctx->model.hparams;
+        const int align_size = params.patch_size * params.n_merge;
+
+        inst.interpolation_overview = img_tool::RESIZE_ALGO_BILINEAR;
+        inst.interpolation_refined  = img_tool::RESIZE_ALGO_BILINEAR;
+        inst.overview_size          = img_tool::calc_size_preserved_ratio(original_size, align_size, params.image_min_pixels, params.image_max_pixels);
+
+        // tile if either dimension exceeds tile_size with tolerance
+        const bool needs_tiling = original_size.width > tile_size * max_pixels_tolerance || original_size.height > tile_size * max_pixels_tolerance;
+
+        if (!needs_tiling) {
+            inst.refined_size = clip_image_size{0, 0};
+            inst.grid_size    = clip_image_size{0, 0};
+            return inst;
+        }
+
+        const clip_image_size grid = get_grid_layout(original_size.height, original_size.width);
+
+        inst.grid_size    = grid;
+        inst.refined_size = clip_image_size{tile_size * grid.width, tile_size * grid.height};
+
+        LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n",
+                __func__,
+                original_size.width, original_size.height,
+                inst.overview_size.width, inst.overview_size.height,
+                inst.refined_size.width, inst.refined_size.height,
+                grid.width, grid.height);
+
+        for (int row = 0; row < grid.height; row++) {
+            for (int col = 0; col < grid.width; col++) {
+                llava_uhd::slice_coordinates slice;
+                slice.x    = col * tile_size;
+                slice.y    = row * tile_size;
+                slice.size = clip_image_size{tile_size, tile_size};
+                inst.slices.push_back(slice);
+                LOG_DBG("%s: slice %d: x=%d, y=%d, size=%d x %d\n",
+                        __func__, (int)inst.slices.size() - 1,
+                        slice.x, slice.y, slice.size.width, slice.size.height);
+            }
+        }
+
+        return inst;
+    }
+
+private:
+    static clip_image_size find_closest_aspect_ratio(
+            float aspect_ratio,
+            const std::vector<clip_image_size> & target_ratios,
+            int width, int height) {
+        float best_ratio_diff = std::numeric_limits<float>::max();
+        clip_image_size best_ratio = {1, 1};
+        const float area = static_cast<float>(width * height);
+
+        for (const auto & ratio : target_ratios) {
+            const float target_aspect_ratio = static_cast<float>(ratio.width) / ratio.height;
+            const float ratio_diff = std::abs(aspect_ratio - target_aspect_ratio);
+            if (ratio_diff < best_ratio_diff) {
+                best_ratio_diff = ratio_diff;
+                best_ratio = ratio;
+            } else if (ratio_diff == best_ratio_diff) {
+                const float target_area = static_cast<float>(tile_size * tile_size * ratio.width * ratio.height);
+                if (area > 0.5f * target_area) {
+                    best_ratio = ratio;
+                }
+            }
+        }
+        return best_ratio;
+    }
+
+    static std::vector<clip_image_size> get_target_ratios() {
+        std::vector<clip_image_size> ratios;
+        for (int n = min_tiles; n <= max_tiles; n++) {
+            for (int w = 1; w <= n; w++) {
+                for (int h = 1; h <= n; h++) {
+                    if (w * h >= min_tiles && w * h <= max_tiles) {
+                        bool found = false;
+                        for (const auto & r : ratios) {
+                            if (r.width == w && r.height == h) {
+                                found = true;
+                                break;
+                            }
+                        }
+                        if (!found) {
+                            ratios.push_back({w, h});
+                        }
+                    }
+                }
+            }
+        }
+        std::sort(ratios.begin(), ratios.end(), [](const clip_image_size & a, const clip_image_size & b) {
+            return a.width * a.height < b.width * b.height;
+        });
+        return ratios;
+    }
+
+    static clip_image_size get_grid_layout(int height, int width) {
+        const float aspect_ratio = static_cast<float>(width) / height;
+        const auto ratios = get_target_ratios();
+        return find_closest_aspect_ratio(aspect_ratio, ratios, width, height);
+    }
+};
+
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
@@ -2652,6 +3046,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
             {
                 GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized;
@@ -2668,6 +3063,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 // res_imgs->data[0] = *res;
                 res_imgs->entries.push_back(std::move(img_f32));
             } break;
+        case PROJECTOR_TYPE_YOUTUVL:
+            {
+                const int patch_size = params.patch_size;  // typically 16
+                const int merge_size = params.n_merge;      // typically 2
+                const int align_size = patch_size * merge_size;  // 32
+
+                const int max_num_patches = params.image_max_pixels > 0 ?
+                    params.image_max_pixels / (patch_size * patch_size) : 256;
+
+                // Linear search for optimal scale to fit within max_num_patches
+                float scale = 1.0f;
+                int target_height = original_size.height;
+                int target_width = original_size.width;
+
+                auto get_scaled_image_size = [align_size](float scale, int size) -> int {
+                    float scaled_size = size * scale;
+                    // Round up to nearest multiple of align_size
+                    int aligned = static_cast<int>(std::ceil(scaled_size / align_size)) * align_size;
+                    // Ensure at least one patch
+                    return std::max(align_size, aligned);
+                };
+
+                // Linear search with 0.02 step size
+                while (scale > 0.0f) {
+                    target_height = get_scaled_image_size(scale, original_size.height);
+                    target_width = get_scaled_image_size(scale, original_size.width);
+
+                    int num_patches_h = target_height / patch_size;
+                    int num_patches_w = target_width / patch_size;
+                    int num_patches = num_patches_h * num_patches_w;
+
+                    if (num_patches > max_num_patches) {
+                        scale -= 0.02f;
+                    } else {
+                        break;
+                    }
+                }
+
+                clip_image_size new_size = {target_width, target_height};
+
+                // Resize the image
+                clip_image_u8 resized;
+                img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
+
+                // Normalize to float32
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+
+                // Add to results
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
 
         case PROJECTOR_TYPE_IDEFICS3:
             {
@@ -2721,6 +3167,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_GLM_EDGE:
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
+        case PROJECTOR_TYPE_NEMOTRON_V2_VL:
             {
                 clip_image_u8 resized_image;
                 int sz = params.image_size;
@@ -2731,6 +3178,16 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 res_imgs->entries.push_back(std::move(img_f32));
             } break;
 
+        case PROJECTOR_TYPE_GEMMA3NV:
+            {
+                clip_image_u8 resized_image;
+                int sz = params.image_size;
+                img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, false);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
+
         case PROJECTOR_TYPE_JANUS_PRO:
             {
                 // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384
@@ -2778,6 +3235,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
             } break;
 
         case PROJECTOR_TYPE_LFM2:
+            {
+                auto const inst = lfm2_vl_image_processor::get_slice_instructions(ctx, original_size);
+                std::vector imgs = llava_uhd::slice_image(img, inst);
+
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
+
+                res_imgs->grid_x = inst.grid_size.width;
+                res_imgs->grid_y = inst.grid_size.height;
+            } break;
+
         case PROJECTOR_TYPE_KIMIVL:
             {
                 GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
@@ -2789,8 +3260,24 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 const std::array<uint8_t, 3> pad_color = {122, 116, 104};
 
                 clip_image_u8 resized_img;
-                const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2);
-                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color);
+                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
+                clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(res));
+            } break;
+
+        case PROJECTOR_TYPE_KIMIK25:
+            {
+                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * params.n_merge,
+                    params.image_min_pixels,
+                    params.image_max_pixels);
+                const std::array<uint8_t, 3> pad_color = {0, 0, 0};
+
+                clip_image_u8 resized_img;
+                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
                 clip_image_f32_ptr res(clip_image_f32_init());
                 normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
                 res_imgs->entries.push_back(std::move(res));
@@ -2900,6 +3387,8 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
+        case PROJECTOR_TYPE_YOUTUVL:
             return (img->nx / params.patch_size) / 2;
         default:
             break;
@@ -2915,6 +3404,8 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
+        case PROJECTOR_TYPE_YOUTUVL:
             return (img->ny / params.patch_size) / 2;
         default:
             break;
@@ -2966,6 +3457,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                     } else if (params.minicpmv_version == 6) {
                         // MiniCPM-V 4.5
                         n_patches = 64;
+                    } else if (params.minicpmv_version == 100045) {
+                        // MiniCPM-o 4.5
+                        n_patches = 64;
                     } else {
                         GGML_ABORT("Unknown minicpmv version");
                     }
@@ -2975,6 +3469,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
         case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_YOUTUVL:
             {
                 // dynamic size (2 conv, so double patch size)
                 int x_patch = img->nx / (params.patch_size * 2);
@@ -2984,14 +3479,22 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_NEMOTRON_V2_VL:
         case PROJECTOR_TYPE_LLAMA4:
             {
                 // both X and Y are downscaled by the scale factor
                 int scale_factor = ctx->model.hparams.n_merge;
                 n_patches /= (scale_factor * scale_factor);
             } break;
+        case PROJECTOR_TYPE_GEMMA3NV:
+            {
+                // MobileNetV5 MSFA adapter always outputs fixed 16x16 resolution
+                // regardless of input size (see architecture description)
+                n_patches = ctx->model.hparams.image_size / ctx->model.hparams.patch_size;
+            } break;
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_KIMIK25:
             {
                 // dynamic size
                 int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
@@ -2999,6 +3502,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
                 n_patches = x_patch * y_patch;
             } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                // dynamic size
+                int n_merge = ctx->model.hparams.n_merge;
+                int stride = n_merge * n_merge;
+                n_patches = CLIP_ALIGN(n_patches, stride) / stride;
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
@@ -3015,6 +3525,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
             {
                 n_patches = img->nx;
 
@@ -3047,6 +3558,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             {
                 n_patches += 2; // for BOI and EOI token embeddings
             } break;
+        case PROJECTOR_TYPE_LFM2A:
+            {
+                n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
+            } break;
         default:
             GGML_ABORT("unsupported projector type");
     }
@@ -3079,7 +3594,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // build the inference graph
-    ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
@@ -3097,7 +3611,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int pos_w = image_size_width  / patch_size;
     const int pos_h = image_size_height / patch_size;
 
-    const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
 
     auto get_inp_tensor = [&gf](const char * name) {
         ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
@@ -3243,12 +3756,38 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                     }
                 }
 
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_PADDLEOCR:
+            {
+                const int merge_ratio = hparams.n_merge;
+                const int pw = image_size_width  / patch_size;
+                const int ph = image_size_height / patch_size;
+                std::vector<int> positions(n_pos * 4);
+                int ptr = 0;
+                // NOTE: same as Qwen-VL, but x and y are swapped
+                for (int y = 0; y < ph; y += merge_ratio) {
+                    for (int dy = 0; dy < 2; dy++) {
+                        for (int x = 0; x < pw; x += merge_ratio) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                positions[                  ptr] = y + dy;
+                                positions[    num_patches + ptr] = x + dx;
+                                positions[2 * num_patches + ptr] = y + dy;
+                                positions[3 * num_patches + ptr] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_YOUTUVL:
             {
                 // pw * ph = number of tokens output by ViT after apply patch merger
                 // ipw * ipw = number of vision token been processed inside ViT
+                const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty();
                 const int merge_ratio = 2;
                 const int pw  = image_size_width  / patch_size / merge_ratio;
                 const int ph  = image_size_height / patch_size / merge_ratio;
@@ -3259,7 +3798,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 std::vector<int> inv_idx(ph * pw);
 
                 if (use_window_attn) {
-                    const int attn_window_size = 112;
+                    const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112;
                     const int grid_window = attn_window_size / patch_size / merge_ratio;
                     int dst = 0;
                     // [num_vision_tokens, num_vision_tokens] attention mask tensor
@@ -3328,6 +3867,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_KIMIK25:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
@@ -3376,13 +3916,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("patches", patches);
             } break;
         case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_GEMMA3NV:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_NEMOTRON_V2_VL:
         case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_GLMA:
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
         case PROJECTOR_TYPE_JANUS_PRO:
         case PROJECTOR_TYPE_COGVLM:
             {
@@ -3405,6 +3948,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 }
                 set_input_i32("pos_w", pos_data);
             } break;
+        case PROJECTOR_TYPE_LFM2A:
+            {
+                GGML_ASSERT(imgs.entries.size() == 1);
+                const auto n_frames = clip_n_output_tokens(ctx, imgs.entries.front().get());
+
+                auto d_model = 512;
+                auto seq_len = n_frames * 2 - 1;
+                std::vector<float> pos_emb(d_model*seq_len);
+                std::vector<float> inv_freq(d_model / 2);
+                for (size_t i = 0; i < inv_freq.size(); ++i) {
+                    inv_freq[i] = std::exp(-(std::log(10000.0) / (float)d_model) * (2.0f * (float)(i)));
+                }
+                for (int64_t pos = 0; pos < seq_len; ++pos) {
+                    for (size_t i = 0; i < inv_freq.size(); ++i) {
+                        const float ang = (n_frames - pos - 1) * inv_freq[i];
+                        pos_emb[pos*d_model + 2*i + 0] = sinf(ang);  // even
+                        pos_emb[pos*d_model + 2*i + 1] = cosf(ang);  // odd
+                    }
+                }
+                set_input_f32("pos_emb", pos_emb);
+            } break;
         default:
             GGML_ABORT("Unknown projector type");
     }
@@ -3425,18 +3989,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false;
     }
 
-    // print debug nodes
-    if (ctx->debug_graph) {
-        LOG_INF("\n\n---\n\n");
-        LOG_INF("\n\nDebug graph:\n\n");
-        for (ggml_tensor * t : ctx->debug_print_tensors) {
-            std::vector<uint8_t> data(ggml_nbytes(t));
-            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
-            print_tensor_shape(t);
-            print_tensor_data(t, data.data(), 3);
-        }
-    }
-
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);
 
@@ -3453,6 +4005,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
     }
 
+    // Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
+    if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) {
+        const int64_t n_embd = embeddings->ne[0];
+        const int64_t n_tokens = embeddings->ne[1];
+        std::vector<float> emb_data(n_embd * n_tokens);
+        ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings));
+
+        LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n");
+        LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens);
+
+        // Print first few values of first token
+        LOG_INF("Token 0 (first 16 values): ");
+        for (int i = 0; i < std::min((int64_t)16, n_embd); i++) {
+            LOG_INF("%.6f ", emb_data[i]);
+        }
+        LOG_INF("\n");
+
+        // Print last few values of first token
+        if (n_embd > 16) {
+            LOG_INF("Token 0 (last 16 values):  ");
+            for (int64_t i = n_embd - 16; i < n_embd; i++) {
+                LOG_INF("%.6f ", emb_data[i]);
+            }
+            LOG_INF("\n");
+        }
+
+        // Compute and print statistics
+        float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0];
+        for (size_t i = 0; i < emb_data.size(); i++) {
+            sum += emb_data[i];
+            sum_sq += emb_data[i] * emb_data[i];
+            min_val = std::min(min_val, emb_data[i]);
+            max_val = std::max(max_val, emb_data[i]);
+        }
+        float mean = sum / emb_data.size();
+        float variance = (sum_sq / emb_data.size()) - (mean * mean);
+        LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n",
+                mean, sqrtf(variance), min_val, max_val, sum);
+        LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n");
+    }
+
     return true;
 }
 
@@ -3475,18 +4068,22 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_JANUS_PRO:
+        case PROJECTOR_TYPE_YOUTUVL:
             return ctx->model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_QWEN3VL:
             // main path + deepstack paths
             return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers);
         case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_GEMMA3NV:
             return ctx->model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->model.projection->ne[1];
         case PROJECTOR_TYPE_ULTRAVOX:
         case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_NEMOTRON_V2_VL:
             return ctx->model.mm_3_w->ne[1];
         case PROJECTOR_TYPE_LLAMA4:
             return ctx->model.mm_model_proj->ne[1];
@@ -3496,9 +4093,13 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_PADDLEOCR:
+        case PROJECTOR_TYPE_KIMIK25:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_COGVLM:
             return ctx->model.mm_4h_to_h_w->ne[1];
+        case PROJECTOR_TYPE_LFM2A:
+            return ctx->model.position_embeddings->ne[0];
         case PROJECTOR_TYPE_GLM4V:
             return ctx->model.mm_ffn_down_w->ne[1];
         default:
@@ -3507,6 +4108,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
 }
 
 int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    // TODO: remove this function
     if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
         return ctx->model.hparams.minicpmv_version;
     }
@@ -3514,24 +4116,14 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
 }
 
 bool clip_is_glm(const struct clip_ctx * ctx) {
+    // TODO: remove this function
     return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
 }
 
-bool clip_is_mrope(const struct clip_ctx * ctx) {
-    return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL
-        || ctx->proj_type() == PROJECTOR_TYPE_GLM4V;
-}
-
 bool clip_is_llava(const struct clip_ctx * ctx) {
     return ctx->model.hparams.has_llava_projector;
 }
 
-bool clip_is_gemma3(const struct clip_ctx * ctx) {
-    return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
-}
-
 bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
     return ctx->model.modality == CLIP_MODALITY_VISION;
 }
@@ -3541,10 +4133,16 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
 }
 
 bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
-    return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
-        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
-        || ctx->proj_type() == PROJECTOR_TYPE_GLMA
-        || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
+    switch (ctx->proj_type()) {
+        case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_GLMA:
+        case PROJECTOR_TYPE_VOXTRAL:
+        case PROJECTOR_TYPE_MUSIC_FLAMINGO:
+            return true;
+        default:
+            return false;
+    }
 }
 
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
@@ -3586,7 +4184,6 @@ const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
 //
 // API for debugging
 //
-
 void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) {
     clip_image_f32 img;
     img.nx = w;
@@ -3595,9 +4192,6 @@ void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) {
     for (int i = 0; i < h * w * 3; i++) {
         img.buf[i] = static_cast<float>(fill_value);
     }
-    bool cur_debug_graph = ctx->debug_graph;
-    ctx->debug_graph = true;
     clip_image_encode(ctx, 1, &img, nullptr);
-    ctx->debug_graph = cur_debug_graph;
     GGML_ASSERT(img.buf.empty() && "expected, always stop here");
 }
diff --git a/llama/llama.cpp/tools/mtmd/clip.h b/llama/llama.cpp/tools/mtmd/clip.h
index 68a0d6e857e..71b58484d6b 100644
--- a/llama/llama.cpp/tools/mtmd/clip.h
+++ b/llama/llama.cpp/tools/mtmd/clip.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "mtmd.h"
 
 #include <stddef.h>
 #include <stdint.h>
@@ -37,6 +38,8 @@ struct clip_context_params {
     int image_min_tokens;
     int image_max_tokens;
     bool warmup;
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 struct clip_init_result {
@@ -104,9 +107,9 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct
 
 int clip_is_minicpmv(const struct clip_ctx * ctx);
 bool clip_is_glm(const struct clip_ctx * ctx);
-bool clip_is_mrope(const struct clip_ctx * ctx);
 bool clip_is_llava(const struct clip_ctx * ctx);
-bool clip_is_gemma3(const struct clip_ctx * ctx);
+// note for contributor: this clip_is_(model) pattern is deprecated
+//                       do NOT add new functions like this
 
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
 
diff --git a/llama/llama.cpp/tools/mtmd/models/conformer.cpp b/llama/llama.cpp/tools/mtmd/models/conformer.cpp
new file mode 100644
index 00000000000..9b1fab48739
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/conformer.cpp
@@ -0,0 +1,216 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_conformer::build() {
+    const int n_frames   = img.nx;
+    const int n_pos      = n_frames / 2;
+    const int n_pos_embd = (((((n_frames + 1) / 2) + 1) / 2 + 1) / 2) * 2 - 1;
+    GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
+
+    ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 512, n_pos_embd);
+    ggml_set_name(pos_emb, "pos_emb");
+    ggml_set_input(pos_emb);
+    ggml_build_forward_expand(gf, pos_emb);
+
+    ggml_tensor * inp = build_inp_raw(1);
+
+    auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+
+    // pre encode, conv subsampling
+    {
+        // layer.0 - conv2d
+        cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[0], cur, 2, 2, 1, 1, 1, 1);
+        cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[0]);
+        cb(cur, "conformer.pre_encode.conv.{}", 0);
+
+        // layer.1 - relu
+        cur = ggml_relu_inplace(ctx0, cur);
+
+        // layer.2 conv2d dw
+        cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[2], cur, 2, 2, 1, 1, 1, 1);
+        cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[2]);
+        cb(cur, "conformer.pre_encode.conv.{}", 2);
+
+        // layer.3 conv2d
+        cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[3], cur, 1, 1, 0, 0, 1, 1);
+        cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[3]);
+        cb(cur, "conformer.pre_encode.conv.{}", 3);
+
+        // layer.4 - relu
+        cur = ggml_relu_inplace(ctx0, cur);
+
+        // layer.5 conv2d dw
+        cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[5], cur, 2, 2, 1, 1, 1, 1);
+        cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[5]);
+        cb(cur, "conformer.pre_encode.conv.{}", 5);
+
+        // layer.6 conv2d
+        cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[6], cur, 1, 1, 0, 0, 1, 1);
+        cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[6]);
+        cb(cur, "conformer.pre_encode.conv.{}", 6);
+
+        // layer.7 - relu
+        cur = ggml_relu_inplace(ctx0, cur);
+
+        // flatten channel and frequency axis
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
+
+        // calculate out
+        cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur);
+        cur = ggml_add(ctx0, cur, model.pre_encode_out_b);
+        cb(cur, "conformer.pre_encode.out", -1);
+    }
+
+    // pos_emb
+    cb(pos_emb, "pos_emb", -1);
+
+    for (int il = 0; il < hparams.n_layer; il++) {
+        const auto & layer = model.layers[il];
+
+        auto * residual = cur;
+
+        cb(cur, "layer.in", il);
+
+        // feed_forward1
+        cur = build_norm(cur, layer.ff_norm_w, layer.ff_norm_b, NORM_TYPE_NORMAL, 1e-5, il);
+        cb(cur, "conformer.layers.{}.norm_feed_forward1", il);
+
+        cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, FFN_SILU,
+                        il);
+        cb(cur, "conformer.layers.{}.feed_forward1.linear2", il);
+
+        const auto fc_factor = 0.5f;
+        residual             = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
+
+        // self-attention
+        {
+            cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il);
+            cb(cur, "conformer.layers.{}.norm_self_att", il);
+
+            ggml_tensor * Qcur     = ggml_mul_mat(ctx0, layer.q_w, cur);
+            Qcur                   = ggml_add(ctx0, Qcur, layer.q_b);
+            Qcur                   = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]);
+            ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u);
+            Q_bias_u               = ggml_permute(ctx0, Q_bias_u, 0, 2, 1, 3);
+            ggml_tensor * Q_bias_v = ggml_add(ctx0, Qcur, layer.pos_bias_v);
+            Q_bias_v               = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3);
+
+            // TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases
+            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+            Kcur               = ggml_add(ctx0, Kcur, layer.k_b);
+            Kcur               = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, Kcur->ne[1]);
+            Kcur               = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+            Vcur               = ggml_add(ctx0, Vcur, layer.v_b);
+            Vcur               = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]);
+            Vcur               = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3));
+
+            // build_attn won't fit due to matrix_ac and matrix_bd separation
+            ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Q_bias_u, Kcur);
+            matrix_ac               = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3));
+            cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il);
+
+            auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb);
+            cb(p, "conformer.layers.{}.self_attn.linear_pos", il);
+            p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]);
+            p = ggml_permute(ctx0, p, 0, 2, 1, 3);
+
+            auto * matrix_bd = ggml_mul_mat(ctx0, Q_bias_v, p);
+            matrix_bd        = ggml_cont(ctx0, ggml_permute(ctx0, matrix_bd, 1, 0, 2, 3));
+
+            // rel shift
+            {
+                const auto pos_len = matrix_bd->ne[0];
+                const auto q_len   = matrix_bd->ne[1];
+                const auto h       = matrix_bd->ne[2];
+                matrix_bd          = ggml_pad(ctx0, matrix_bd, 1, 0, 0, 0);
+                matrix_bd          = ggml_roll(ctx0, matrix_bd, 1, 0, 0, 0);
+                matrix_bd          = ggml_reshape_3d(ctx0, matrix_bd, q_len, pos_len + 1, h);
+                matrix_bd          = ggml_view_3d(ctx0, matrix_bd, q_len, pos_len, h, matrix_bd->nb[1],
+                                                        matrix_bd->nb[2], matrix_bd->nb[0] * q_len);
+                matrix_bd          = ggml_cont_3d(ctx0, matrix_bd, pos_len, q_len, h);
+            }
+
+            matrix_bd     = ggml_view_3d(ctx0, matrix_bd, matrix_ac->ne[0], matrix_bd->ne[1],
+                                               matrix_bd->ne[2], matrix_bd->nb[1], matrix_bd->nb[2], 0);
+            auto * scores = ggml_add(ctx0, matrix_ac, matrix_bd);
+            scores        = ggml_scale(ctx0, scores, 1.0f / std::sqrt(d_head));
+            cb(scores, "conformer.layers.{}.self_attn.id0", il);
+
+            ggml_tensor * attn = ggml_soft_max(ctx0, scores);
+            ggml_tensor * x    = ggml_mul_mat(ctx0, attn, Vcur);
+            x                  = ggml_permute(ctx0, x, 2, 0, 1, 3);
+            x                  = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]);
+
+            ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x);
+            out               = ggml_add(ctx0, out, layer.o_b);
+            cb(out, "conformer.layers.{}.self_attn.linear_out", il);
+
+            cur = out;
+        }
+
+        residual = ggml_add(ctx0, residual, cur);
+        cur      = build_norm(residual, layer.norm_conv_w, layer.norm_conv_b, NORM_TYPE_NORMAL, 1e-5, il);
+        cb(cur, "conformer.layers.{}.norm_conv", il);
+
+        // conv
+        {
+            auto * x = cur;
+            x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x);
+            x = ggml_add(ctx0, x, layer.conv_pw1_b);
+            cb(x, "conformer.layers.{}.conv.pointwise_conv1", il);
+
+            // ggml_glu doesn't support sigmoid
+            // TODO @ngxson : support this ops in ggml
+            {
+                int64_t       d    = x->ne[0] / 2;
+                ggml_tensor * gate = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0]));
+                x                  = ggml_mul(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate);
+                x                  = ggml_cont(ctx0, ggml_transpose(ctx0, x));
+            }
+
+            // use ggml_ssm_conv for f32 precision
+            x = ggml_pad(ctx0, x, 4, 0, 0, 0);
+            x = ggml_roll(ctx0, x, 4, 0, 0, 0);
+            x = ggml_pad(ctx0, x, 4, 0, 0, 0);
+            x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w);
+            x = ggml_add(ctx0, x, layer.conv_dw_b);
+
+            x = ggml_add(ctx0, ggml_mul(ctx0, x, layer.conv_norm_w), layer.conv_norm_b);
+            x = ggml_silu(ctx0, x);
+
+            // pointwise_conv2
+            x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x);
+            x = ggml_add(ctx0, x, layer.conv_pw2_b);
+
+            cur = x;
+        }
+
+        residual = ggml_add(ctx0, residual, cur);
+
+        cur = build_norm(residual, layer.ff_norm_1_w, layer.ff_norm_1_b, NORM_TYPE_NORMAL, 1e-5, il);
+        cb(cur, "conformer.layers.{}.norm_feed_forward2", il);
+
+        cur = build_ffn(cur, layer.ff_up_1_w, layer.ff_up_1_b, nullptr, nullptr, layer.ff_down_1_w, layer.ff_down_1_b,
+                        FFN_SILU, il);  // TODO(tarek): read activation for ffn from hparams
+        cb(cur, "conformer.layers.{}.feed_forward2.linear2", il);
+
+        residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor));
+        cb(residual, "conformer.layers.{}.conv.id", il);
+
+        cur = build_norm(residual, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, 1e-5, il);
+        cb(cur, "conformer.layers.{}.norm_out", il);
+    }
+
+    // audio adapter
+    cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+    cb(cur, "audio_adapter.model.{}", 0);
+    cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_3_w, model.mm_3_b, FFN_GELU_ERF, -1);
+
+    cb(cur, "projected", -1);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/models/glm4v.cpp b/llama/llama.cpp/tools/mtmd/models/glm4v.cpp
index f39b6922eb5..6f52df41ab0 100644
--- a/llama/llama.cpp/tools/mtmd/models/glm4v.cpp
+++ b/llama/llama.cpp/tools/mtmd/models/glm4v.cpp
@@ -2,7 +2,6 @@
 
 ggml_cgraph * clip_graph_glm4v::build() {
     GGML_ASSERT(model.patch_bias != nullptr);
-    GGML_ASSERT(model.position_embeddings != nullptr);
     GGML_ASSERT(model.class_embedding == nullptr);
 
     const int batch_size = 1;
@@ -45,19 +44,22 @@ ggml_cgraph * clip_graph_glm4v::build() {
     // pos-conv norm
     inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);
 
-    // calculate absolute position embedding and apply
-    ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
-    learned_pos_embd = ggml_cont_4d(
-        ctx0, learned_pos_embd,
-        n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
-    learned_pos_embd = ggml_reshape_4d(
-        ctx0, learned_pos_embd,
-        n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
-    learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
-    learned_pos_embd = ggml_cont_3d(
-        ctx0, learned_pos_embd,
-        n_embd, n_patches_x * n_patches_y, batch_size);
-    cb(learned_pos_embd, "learned_pos_embd", -1);
+    ggml_tensor * learned_pos_embd = nullptr;
+    // Note: GLM-OCR does not have learned position embeddings
+    if (model.position_embeddings != nullptr) {
+        learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
+        learned_pos_embd = ggml_cont_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        learned_pos_embd = ggml_reshape_4d(
+            ctx0, learned_pos_embd,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
+        learned_pos_embd = ggml_cont_3d(
+            ctx0, learned_pos_embd,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+        cb(learned_pos_embd, "learned_pos_embd", -1);
+    }
 
     auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
         return ggml_rope_multi(
diff --git a/llama/llama.cpp/tools/mtmd/models/kimik25.cpp b/llama/llama.cpp/tools/mtmd/models/kimik25.cpp
new file mode 100644
index 00000000000..cf9f27f63af
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/kimik25.cpp
@@ -0,0 +1,101 @@
+#include "models.h"
+#include 
+#include 
+
+// note: this is similar to clip_graph::resize_position_embeddings, major difference is having
+// the w/h in ne[1] and ne[2] instead of assuming with sqrt. Could try storing the tensor in 2D instead
+// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3).
+ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) {
+    ggml_tensor * pos_embd = model.position_embeddings;
+    const int height       = img.ny / patch_size;
+    const int width        = img.nx / patch_size;
+    const uint32_t mode    = interpolation_mode;
+
+    GGML_ASSERT(pos_embd);
+
+    const int64_t stored_c = pos_embd->ne[0];  // C = 1152
+    const int64_t orig_w = pos_embd->ne[1];    // W = 64
+    const int64_t orig_h = pos_embd->ne[2];    // H = 64
+
+    GGML_ASSERT(stored_c == n_embd);
+
+    if (height == (int)orig_h && width == (int)orig_w) {
+        // No interpolation needed, just flatten to [C, H*W]
+        return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
+    }
+
+    pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
+    pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode);
+    pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
+    pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
+    return pos_embd;
+}
+
+ggml_cgraph * clip_graph_kimik25::build() {
+    ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+    ggml_set_name(pos_h, "pos_h");
+    ggml_set_input(pos_h);
+
+    ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+    ggml_set_name(pos_w, "pos_w");
+    ggml_set_input(pos_w);
+
+    ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
+
+    // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but
+    // Q / K are permuted during conversion to use split format.
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+        cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        return cur;
+    };
+
+    ggml_tensor * inp = build_inp();
+
+    // I don't know why, but doing this in the build_vit lead to the ggml_add not occurring?
+    // Doing it manually here does work.
+    inp = ggml_add(ctx0, inp, learned_pos_embd);
+
+    ggml_tensor * cur = build_vit(
+                            inp, n_patches,
+                            NORM_TYPE_NORMAL,
+                            hparams.ffn_op,
+                            nullptr,
+                            add_pos);
+
+    cb(cur, "vit_out", -1);
+
+    {
+        // patch_merger
+        const int scale_factor = model.hparams.n_merge;
+        cur = build_patch_merge_permute(cur, scale_factor);
+
+        // projection norm
+        int proj_inp_dim = cur->ne[0];
+        int n_merged_patches = cur->ne[1];
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_merged_patches * scale_factor * scale_factor,
+            ggml_row_size(cur->type, n_embd), 0);
+        cur = ggml_norm(ctx0, cur, hparams.eps);
+        cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+        cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+        cur = ggml_view_2d(ctx0, cur,
+            proj_inp_dim, n_merged_patches,
+            ggml_row_size(cur->type, proj_inp_dim), 0);
+        cb(cur, "proj_inp_normed", -1);
+
+        // projection mlp
+        cur = build_ffn(cur,
+            model.mm_1_w, model.mm_1_b,
+            nullptr, nullptr,
+            model.mm_2_w, model.mm_2_b,
+            FFN_GELU,
+            -1);
+
+        cb(cur, "proj_out", -1);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp b/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp
new file mode 100644
index 00000000000..593afa1ddce
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp
@@ -0,0 +1,451 @@
+#include "models.h"
+
+// Helpers for MobileNetV5 Blocks
+// RMS Norm 2D - normalizes over channels for each spatial position
+ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) {
+    // inp: [W, H, C, B]
+
+    ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3);
+    cur = ggml_cont(ctx0, cur);
+    cur = ggml_rms_norm(ctx0, cur, eps);
+
+    if (weight) {
+        cur = ggml_mul(ctx0, cur, weight);
+    }
+
+    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3);
+    cur = ggml_cont(ctx0, cur);
+
+    return cur;
+}
+
+// Conv2dSame padding - asymmetric SAME padding like PyTorch/TF
+ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) {
+    const int64_t ih = inp->ne[1];  // height
+    const int64_t iw = inp->ne[0];  // width
+
+    // Calculate output size (ceil division)
+    const int64_t oh = (ih + stride_h - 1) / stride_h;
+    const int64_t ow = (iw + stride_w - 1) / stride_w;
+
+    // Calculate padding needed
+    const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih);
+    const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw);
+
+    // Split padding asymmetrically
+    const int pad_h_top = pad_h / 2;
+    const int pad_h_bottom = pad_h - pad_h_top;
+    const int pad_w_left = pad_w / 2;
+    const int pad_w_right = pad_w - pad_w_left;
+
+    // Apply padding if needed
+    // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
+    // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch
+    if (pad_h > 0 || pad_w > 0) {
+        inp = ggml_pad_ext(ctx0, inp,
+            pad_w_left, pad_w_right,     // width padding (dim 0)
+            pad_h_top, pad_h_bottom,      // height padding (dim 1)
+            0, 0,                         // no channel padding (dim 2)
+            0, 0);                        // no batch padding (dim 3)
+    }
+
+    return inp;
+}
+
+
+// Edge Residual Block (Stage 0)
+ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
+    ggml_tensor * cur = inp;
+
+    // 1. Expansion Conv (3x3)
+    if (stride == 2) {
+        // Case: Downsampling (Block 0)
+        // Replicates Conv2dSame(kernel=3, stride=2)
+        cur = pad_same_2d(cur, 3, 3, stride, stride);
+        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1);
+    } else {
+        // Case: Normal 3x3 Block (Block 1, 2)
+        // Replicates Conv2d(kernel=3, stride=1, padding=1)
+        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1);
+    }
+
+    // BN + Activation
+    if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w);
+    cur = ggml_gelu(ctx0, cur);
+
+    // 2. Pointwise Linear Conv (1x1)
+    // 1x1 Convs usually have padding=0 and stride=1
+    cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1);
+    if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w);
+
+    // 3. Residual Connection
+    // Only apply residual if spatial dimensions and channels match (stride 1)
+    if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) {
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+// Universal Inverted Residual Block (Stage 1+)
+ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
+    ggml_tensor * cur = inp;
+
+    // 1. Depthwise Start (Optional)
+    // NOTE: dw_start always has stride=1 (no downsampling here)
+    if (block.dw_start_w) {
+        int k = block.dw_start_w->ne[0]; // 3 or 5
+        int p = k / 2;
+        cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1);
+        if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w);
+    }
+
+    // 2. Pointwise Expansion (1x1)
+    if (block.pw_exp_w) {
+        // Standard 1x1 conv, pad=0, stride=1
+        cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1);
+        if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w);
+        cur = ggml_gelu(ctx0, cur);
+    }
+
+    // 3. Depthwise Mid (Optional)
+    // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage)
+    if (block.dw_mid_w) {
+        int k = block.dw_mid_w->ne[0]; // 3 or 5
+
+        if (stride > 1) {
+            // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding
+            cur = pad_same_2d(cur, k, k, stride, stride);
+            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0
+        } else {
+            // Case: Stride 1 -> Use Standard Symmetric Padding
+            int p = k / 2;
+            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1);
+        }
+
+        if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w);
+        cur = ggml_gelu(ctx0, cur);
+    }
+
+    // 4. Pointwise Projection (1x1)
+    if (block.pw_proj_w) {
+        cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1);
+        if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w);
+    }
+
+    // Apply Layer Scaling if present
+    if (block.layer_scale_w) {
+        cur = ggml_mul(ctx0, cur, block.layer_scale_w);
+    }
+
+    // 5. Residual Connection
+    bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]);
+    bool same_channel = (inp->ne[2] == cur->ne[2]);
+    if (same_spatial && same_channel) {
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+// Attention Block (MQA)
+ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) {
+    ggml_tensor * cur = inp;
+
+    // Norm
+    if (block.attn_norm_w) {
+        cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f);
+    }
+
+    // 1. Q Calculation
+    ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1);
+
+    // 2. K Calculation (Downsampled)
+    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
+    ggml_tensor * k_inp = cur;
+    if (block.attn_k_dw_w) {
+        int k_size = block.attn_k_dw_w->ne[0];  // Usually 3
+        k_inp = pad_same_2d(cur, k_size, k_size, 2, 2);  // Apply SAME padding
+        k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1);  // padding=0
+        if (block.attn_k_norm_w) {
+            k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f);
+        }
+    }
+    ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1);
+
+    // 3. V Calculation (Downsampled)
+    // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
+    ggml_tensor * v_inp = cur;
+    if (block.attn_v_dw_w) {
+        int v_size = block.attn_v_dw_w->ne[0];  // Usually 3
+        v_inp = pad_same_2d(cur, v_size, v_size, 2, 2);  // Apply SAME padding
+        v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1);  // padding=0
+        if (block.attn_v_norm_w) {
+            v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f);
+        }
+    }
+    ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1);
+
+    const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3];
+    const int D = k->ne[2]; // Head dimension
+    const int n_head = q->ne[2] / D;
+    const int N = W * H;
+
+    // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B]
+    q = ggml_reshape_3d(ctx0, q, N, D*n_head, B);
+    q = ggml_reshape_4d(ctx0, q, N, D, n_head, B);
+    q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B]
+    q = ggml_cont(ctx0, q);
+
+    const int Wk = k->ne[0]; const int Hk = k->ne[1];
+    const int M = Wk * Hk;
+
+    // Process K: [Wk, Hk, D, B] -> [D, M, 1, B]
+    k = ggml_reshape_3d(ctx0, k, M, D, B);
+    k = ggml_reshape_4d(ctx0, k, M, D, 1, B);
+    k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B]
+    k = ggml_cont(ctx0, k);
+
+    // Process V: [Wk, Hk, D, B] -> [M, D, 1, B]
+    v = ggml_reshape_3d(ctx0, v, M, D, B);
+    v = ggml_reshape_4d(ctx0, v, M, D, 1, B);
+    v = ggml_cont(ctx0, v); // [M, D, 1, B]
+
+    // Multi-Query Attention
+    float scale = 1.0f / sqrtf((float)D);
+
+    // Step 1: Compute Q @ K.T
+    ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
+
+    scores = ggml_scale(ctx0, scores, scale);
+
+    scores = ggml_soft_max(ctx0, scores);
+
+    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores);
+
+    kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3);
+    kqv = ggml_cont(ctx0, kqv);
+
+
+    kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B);
+    kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B);
+    kqv = ggml_cont(ctx0, kqv);
+
+    // Output projection
+    cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1);
+
+    // Residual & Layer Scale
+    if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) {
+        if (block.layer_scale_w) {
+            cur = ggml_mul(ctx0, cur, block.layer_scale_w);
+        }
+        cur = ggml_add(ctx0, cur, inp);
+    }
+
+    return cur;
+}
+
+ggml_cgraph * clip_graph_mobilenetv5::build() {
+    ggml_tensor * inp = build_inp_raw();
+
+    // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2))
+    ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2);  // Apply SAME padding
+
+    cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1);  // padding=0
+    if (model.mobilenet_stem_conv_b) {
+        cur = ggml_add(ctx0, cur, model.mobilenet_stem_conv_b);
+    }
+    if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w);
+    cur = ggml_gelu(ctx0, cur);
+
+
+    // 2. Blocks
+    std::vector intermediate_features;
+    const int total_blocks = model.mobilenet_blocks.size();
+
+    auto is_stage_start = [&](int i) {
+        if (i == 0) return true;
+        for (int end_idx : model.mobilenet_stage_ends) {
+            if (i == end_idx + 1) return true;
+        }
+        return false;
+    };
+
+    auto is_fusion_point = [&](int i) {
+        if (model.mobilenet_stage_ends.size() >= 4) {
+                if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2
+                if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3
+        } else {
+            if (i == total_blocks - 1) return true;
+        }
+        return false;
+    };
+
+    for (int i = 0; i < total_blocks; i++) {
+        const auto & block = model.mobilenet_blocks[i];
+        int stride = is_stage_start(i) ? 2 : 1;
+
+        if (block.s0_conv_exp_w)      cur = build_edge_residual(cur, block, stride);
+        else if (block.attn_q_w)      cur = build_mobilenet_attn(cur, block);
+        else                          cur = build_inverted_residual(cur, block, stride);
+
+        if (is_fusion_point(i)) {
+
+            intermediate_features.push_back(cur);
+        }
+    }
+
+    // 3. Multi-Scale Fusion Adapter (MSFA)
+    if (!intermediate_features.empty()) {
+
+        // A. Reference Resolution: PyTorch implementation uses inputs[0]
+        // We assume intermediate_features[0] is the "High Resolution" target.
+        // In MobileNet designs, this is typically the feature map with the smallest stride (e.g. 32x32).
+        ggml_tensor* target_feat = intermediate_features[0];
+        int high_res_w = target_feat->ne[0];
+        int high_res_h = target_feat->ne[1];
+
+        std::vector resized_feats;
+
+        // B. Resize inputs to match inputs[0] (High Resolution)
+        for (auto feat : intermediate_features) {
+            int feat_w = feat->ne[0];
+            int feat_h = feat->ne[1];
+
+            // PyTorch: if feat_size < high_resolution: interpolate
+            if (feat_w < high_res_w || feat_h < high_res_h) {
+                // Calculate scale factor.
+                // Note: PyTorch 'nearest' works on arbitrary float scales.
+                // ggml_upscale generally takes integer factors or target sizes depending on helper.
+                // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2).
+                int scale_w = high_res_w / feat_w;
+                // int scale_h = high_res_h / feat_h;
+
+                // Safety check for non-integer scaling if strictly replicating
+                GGML_ASSERT(high_res_w % feat_w == 0);
+
+                // Upsample (Nearest Neighbor)
+                // 2 is the scale factor
+                feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST);
+            }
+            resized_feats.push_back(feat);
+        }
+
+        // C. Concatenate at High Resolution (Channel Dim = 2 in ggml)
+        cur = resized_feats[0];
+        for (size_t k = 1; k < resized_feats.size(); ++k) {
+            cur = ggml_concat(ctx0, cur, resized_feats[k], 2);
+        }
+
+        // D. FFN (UniversalInvertedResidual)
+        // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm
+
+        // 1. Expansion
+        if (model.msfa_ffn_expand_w) {
+            // 1x1 Conv
+            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1);
+
+            if (model.msfa_ffn_expand_bn) {
+                cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn);
+            }
+
+            cur = ggml_gelu(ctx0, cur);
+
+        }
+
+        // 2. Projection (No DW because kernel_size=0)
+        if (model.msfa_ffn_project_w) {
+            // 1x1 Conv
+            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1);
+
+            // UniversalInvertedResidual typically has a norm after projection
+            if (model.msfa_ffn_project_bn) {
+                cur = rms_norm_2d(cur, model.msfa_ffn_project_bn);
+            }
+
+        }
+
+        // E. Final Downsample to Target Resolution (Output Resolution)
+        // PyTorch: matches self.output_resolution (e.g. 16x16)
+        const int target_out_res = 16;
+        int current_w = cur->ne[0];
+
+        if (current_w > target_out_res) {
+            int s = current_w / target_out_res;
+
+            GGML_ASSERT(current_w % target_out_res == 0);
+
+            // Avg Pool: Kernel=s, Stride=s
+            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0);
+
+        }
+
+        // F. Final Norm
+        if (model.msfa_concat_norm_w) {
+            cur = rms_norm_2d(cur, model.msfa_concat_norm_w);
+
+        }
+    }
+
+    // 4. Gemma 3n Multimodal Projection (Embedder)
+    // Input: 'cur' is [Width, Height, Channels, Batch]
+    int W = cur->ne[0];
+    int H = cur->ne[1];
+    int C = cur->ne[2];
+    int B = cur->ne[3];
+
+    GGML_ASSERT(C == hparams.n_embd);
+
+    // 1. Permute and Flatten to [Channels, Tokens, Batch]
+    // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch)
+    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B]
+    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B]
+    cur = ggml_cont(ctx0, cur);
+    cur = ggml_reshape_3d(ctx0, cur, C, W*H, B);
+    cur = ggml_cont(ctx0, cur);
+
+
+    // 2. FEATURE SCALING
+    // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5
+    const float scale_factor = sqrtf((float)C);
+    cur = ggml_scale(ctx0, cur, scale_factor);
+
+
+    // 3. SOFT EMBEDDING NORM
+    // PyTorch: self._norm(x) * self.weight
+    // We must normalize regardless, then multiply if weight exists.
+    {
+        const float eps = 1e-6f; // Gemma3n uses 1e-6
+        cur = ggml_rms_norm(ctx0, cur, eps);
+
+        if (model.mm_soft_emb_norm_w) {
+            // Weight shape is (2048,) -> Element-wise broadcast multiply
+            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
+        }
+
+    }
+
+    // 4. PROJECTION
+    // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
+    // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
+    if (model.mm_input_proj_w) {
+        cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
+    }
+
+    // 5. POST PROJECTION NORM
+    // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False)
+    // with_scale=False means weight is registered as buffer with value 1.0
+    // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1
+    {
+        const float eps = 1e-6f;
+        cur = ggml_rms_norm(ctx0, cur, eps);
+
+        if (model.mm_post_proj_norm_w) {
+            // If weight is loaded, multiply (should be ~1.0 anyway)
+            cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w);
+        }
+    }
+
+    ggml_build_forward_expand(gf, cur);
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/models/models.h b/llama/llama.cpp/tools/mtmd/models/models.h
index 0496d6b22f1..aff222c71d3 100644
--- a/llama/llama.cpp/tools/mtmd/models/models.h
+++ b/llama/llama.cpp/tools/mtmd/models/models.h
@@ -2,6 +2,11 @@
 
 #include "../clip-graph.h"
 
+/*
+ * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
+ * We encourage human contributors to ensure the quality and reliability of the codebase.
+ */
+
 struct clip_graph_siglip : clip_graph {
     clip_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
@@ -22,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_youtuvl : clip_graph {
+    clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_minicpmv : clip_graph {
     clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
@@ -32,6 +42,11 @@ struct clip_graph_internvl : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_nemotron_v2_vl : clip_graph {
+    clip_graph_nemotron_v2_vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_llama4 : clip_graph {
     clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
@@ -42,6 +57,11 @@ struct clip_graph_kimivl : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_paddleocr : clip_graph {
+    clip_graph_paddleocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_cogvlm : clip_graph {
     clip_graph_cogvlm(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
@@ -57,7 +77,52 @@ struct clip_graph_whisper_enc : clip_graph {
     ggml_cgraph * build() override;
 };
 
+struct clip_graph_conformer : clip_graph {
+    clip_graph_conformer(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
+
 struct clip_graph_glm4v : clip_graph {
     clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
     ggml_cgraph * build() override;
 };
+
+struct clip_graph_mobilenetv5 : clip_graph {
+    clip_graph_mobilenetv5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+
+    ggml_tensor * rms_norm_2d(
+        ggml_tensor * inp,
+        ggml_tensor * weight,
+        float eps = 1e-6f);
+
+    ggml_tensor* pad_same_2d(
+        ggml_tensor* inp,
+        int kernel_h,
+        int kernel_w,
+        int stride_h,
+        int stride_w,
+        int dilation_h = 1,
+        int dilation_w = 1);
+
+    ggml_tensor * build_edge_residual(
+        ggml_tensor * inp,
+        const mobilenetv5_block & block,
+        int stride);
+
+    ggml_tensor * build_inverted_residual(
+        ggml_tensor * inp,
+        const mobilenetv5_block & block,
+        int stride);
+
+    ggml_tensor * build_mobilenet_attn(
+        ggml_tensor * inp,
+        const mobilenetv5_block & block);
+};
+
+struct clip_graph_kimik25 : clip_graph {
+    clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+
+    ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode);
+};
diff --git a/llama/llama.cpp/tools/mtmd/models/nemotron-v2-vl.cpp b/llama/llama.cpp/tools/mtmd/models/nemotron-v2-vl.cpp
new file mode 100644
index 00000000000..03094be1b27
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/nemotron-v2-vl.cpp
@@ -0,0 +1,35 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_nemotron_v2_vl::build() {
+    GGML_ASSERT(model.class_embedding != nullptr);
+    GGML_ASSERT(model.position_embeddings != nullptr);
+
+    const int n_registers = model.class_embedding->ne[1];
+    const int n_pos = n_patches + n_registers;
+
+    ggml_tensor * inp = build_inp();
+
+    // add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input)
+    inp = ggml_add(ctx0, inp, model.position_embeddings);
+    cb(inp, "inp_pos", -1);
+
+    inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
+
+    ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr);
+
+    cur = ggml_view_2d(ctx0, cur,
+        n_embd, n_patches,
+        ggml_row_size(cur->type, n_embd),
+        n_registers * ggml_row_size(cur->type, n_embd));
+
+    cur = build_patch_merge_permute(cur, model.hparams.n_merge);
+
+    {
+        cur = build_norm(cur, model.mm_0_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
+        cur = build_ffn(cur, model.mm_1_w, nullptr, nullptr, nullptr, model.mm_3_w, nullptr, FFN_RELU_SQR, -1);
+    }
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/models/paddleocr.cpp b/llama/llama.cpp/tools/mtmd/models/paddleocr.cpp
new file mode 100644
index 00000000000..5d3a13fb571
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/paddleocr.cpp
@@ -0,0 +1,52 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_paddleocr::build() {
+    const int n_pos            = n_patches;
+    const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+        return ggml_rope_multi(
+                    ctx0, cur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
+                    32768, 10000, 1, 0, 1, 32, 1);
+    };
+
+    ggml_tensor * learned_pos_embd = resize_position_embeddings();
+    ggml_tensor * inp = build_inp();
+    ggml_tensor * cur = build_vit(
+                            inp, n_patches,
+                            NORM_TYPE_NORMAL,
+                            hparams.ffn_op,
+                            learned_pos_embd,
+                            add_pos);
+
+    cb(cur, "vit_out", -1);
+
+    {
+        // mlp_AR paddleocr projector
+        float proj_norm_eps = 1e-5;
+        cur = build_norm(cur,
+                    model.mm_input_norm_w, model.mm_input_norm_b,
+                    NORM_TYPE_NORMAL, proj_norm_eps, -1);
+
+        const int scale_factor = model.hparams.n_merge;
+        cur = build_patch_merge_permute(cur, scale_factor);
+        cur = build_ffn(cur,
+                    model.mm_1_w, model.mm_1_b,
+                    nullptr, nullptr,
+                    model.mm_2_w, model.mm_2_b,
+                    hparams.ffn_op, -1);
+        cb(cur, "mlp_out", -1);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp b/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp
index 35a42cb84d6..5ecb10fe438 100644
--- a/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp
+++ b/llama/llama.cpp/tools/mtmd/models/qwen3vl.cpp
@@ -182,7 +182,9 @@ ggml_cgraph * clip_graph_qwen3vl::build() {
         model.mm_1_w, model.mm_1_b,
         ffn_op_type::FFN_GELU, -1);
 
-    embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
+    if (deepstack_features) {
+        embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0);
+    } // concat along the feature dimension
 
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
diff --git a/llama/llama.cpp/tools/mtmd/models/siglip.cpp b/llama/llama.cpp/tools/mtmd/models/siglip.cpp
index ef094cfd0eb..b866a11c5aa 100644
--- a/llama/llama.cpp/tools/mtmd/models/siglip.cpp
+++ b/llama/llama.cpp/tools/mtmd/models/siglip.cpp
@@ -50,10 +50,15 @@ ggml_cgraph * clip_graph_siglip::build() {
         const int scale_factor = model.hparams.n_merge;
         cur = build_patch_merge_permute(cur, scale_factor);
 
-        // projection
-        cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
-        cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
-        cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+        // projection, in LFM2-VL input norm is optional
+        if (model.mm_input_norm_w) {
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+        }
+
+        if (model.mm_input_norm_b) {
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+        }
 
         cur = build_ffn(cur,
             model.mm_1_w, model.mm_1_b,
diff --git a/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp
index 2870d854ab8..2f2b1277551 100644
--- a/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp
+++ b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp
@@ -86,6 +86,15 @@ ggml_cgraph * clip_graph_whisper_enc::build() {
             FFN_GELU_ERF,
             -1);
 
+    } else if (proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO) {
+        // projector
+        cur = build_ffn(cur,
+            model.mm_1_w, model.mm_1_b,
+            nullptr, nullptr,
+            model.mm_2_w, model.mm_2_b,
+            FFN_GELU_ERF,
+            -1);
+
     } else if (proj_type == PROJECTOR_TYPE_GLMA) {
             cur = ggml_norm(ctx0, cur, hparams.eps);
             cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
diff --git a/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp b/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp
new file mode 100644
index 00000000000..ffbf2be5547
--- /dev/null
+++ b/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp
@@ -0,0 +1,179 @@
+#include "models.h"
+
+ggml_cgraph * clip_graph_youtuvl::build() {
+    GGML_ASSERT(model.class_embedding == nullptr);
+    const int batch_size       = 1;
+    const bool use_window_attn = !hparams.wa_layer_indexes.empty();
+    const int n_pos            = n_patches;
+    const int num_position_ids = n_pos * 4;
+    const int m = 2;
+    const int Wp = n_patches_x;
+    const int Hp = n_patches_y;
+    const int Hm = Hp / m;
+    const int Wm = Wp / m;
+    norm_type norm_t = NORM_TYPE_NORMAL;
+
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+    ggml_tensor * inp = build_inp_raw();
+
+    // change conv3d to linear
+    // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
+    {
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            Wm * m * patch_size, m * patch_size, Hm, 3);
+        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, Wm, m * patch_size, Hm);
+
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, patch_size, m, Hm * Wm);
+
+        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            patch_size, 3, patch_size, Hm * Wm * m * m);
+
+        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            3*patch_size* patch_size,  Hm * Wm * m * m, 1);
+    }
+    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+
+    if (model.patch_bias) {
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+
+    inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+
+    ggml_tensor * inpL           = inp;
+    ggml_tensor * window_mask    = nullptr;
+    ggml_tensor * window_idx     = nullptr;
+    ggml_tensor * inv_window_idx = nullptr;
+
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    // pre-layernorm
+    if (model.pre_ln_w) {
+        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+    }
+    if (use_window_attn) {
+        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+        // mask for window attention
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        // if flash attn is used, we need to pad the mask and cast to f16
+        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
+        }
+
+        // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+        GGML_ASSERT(batch_size == 1);
+        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        const auto & layer = model.layers[il];
+        const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;
+
+        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+        // layernorm1
+        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+        // self-attention
+        {
+            ggml_tensor * Qcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+            ggml_tensor * Kcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+            ggml_tensor * Vcur = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+            Qcur = ggml_rope_multi(
+                ctx0, Qcur, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            Kcur = ggml_rope_multi(
+                ctx0, Kcur, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+
+            cur = build_attn(layer.o_w, layer.o_b,
+                Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+        }
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, inpL);
+
+        inpL = cur; // inpL = residual, cur = hidden_states
+
+        // layernorm2
+        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+
+        // ffn
+        cur = build_ffn(cur,
+            layer.ff_up_w, layer.ff_up_b,
+            nullptr, nullptr,
+            layer.ff_down_w, layer.ff_down_b,
+            hparams.ffn_op, il);
+
+        // residual 2
+        cur = ggml_add(ctx0, inpL, cur);
+
+        inpL = cur;
+    }
+
+    ggml_tensor * embeddings = inpL;
+    if (use_window_attn) {
+        const int spatial_merge_unit = 4;
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
+        cb(embeddings, "window_order_restored", -1);
+    }
+
+    // post-layernorm (part of Siglip2VisionTransformer, applied after encoder)
+    if (model.post_ln_w) {
+        embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+    }
+
+    // Now apply merger (VLPatchMerger):
+    // 1. Apply RMS norm (ln_q in VLPatchMerger)
+    embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
+    cb(embeddings, "merger_normed", -1);
+
+    // 2. First reshape for spatial merge (merge 2x2 patches)
+    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+    cb(embeddings, "merger_reshaped", -1);
+
+    embeddings = build_ffn(embeddings,
+                    model.mm_0_w, model.mm_0_b,
+                    nullptr, nullptr,
+                    model.mm_1_w, model.mm_1_b,
+                    FFN_GELU,
+                    -1);
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
diff --git a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp
index 2024d3d37a8..a208c778997 100644
--- a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp
+++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp
@@ -9,207 +9,250 @@
 #include 
 #include 
 
-// most of the code here is copied from whisper.cpp
+// some of the code here is copied from whisper.cpp
 
 constexpr bool DEBUG = false;
 
-struct mtmd_audio_mel_filters {
-    int32_t n_mel;
-    int32_t n_fft;
-
-    std::vector<float> data;
-};
-
-// note: this global cache is shared among all preprocessors
-//       if we want to use multiple preprocessors at the same time,
-//       we will need to enclose it in the preprocessor class in the future
-static struct mtmd_audio_global_cache {
-    // precomputed sin/cos table for FFT
-    std::vector<float> sin_vals;
-    std::vector<float> cos_vals;
-
-    // hann window
-    std::vector<float> hann_window;
-
-    // mel filter bank
-    mtmd_audio_mel_filters filters;
-
-    void fill_sin_cos_table(int n) {
-        sin_vals.resize(n);
-        cos_vals.resize(n);
-        for (int i = 0; i < n; i++) {
-            double theta = (2 * M_PI * i) / n;
-            sin_vals[i] = sinf(theta);
-            cos_vals[i] = cosf(theta);
-        }
+void mtmd_audio_cache::fill_sin_cos_table(int n) {
+    sin_vals.resize(n);
+    cos_vals.resize(n);
+    for (int i = 0; i < n; i++) {
+        double theta = (2 * M_PI * i) / n;
+        sin_vals[i]  = sinf(theta);
+        cos_vals[i]  = cosf(theta);
     }
+}
 
-    void fill_hann_window(int length, bool periodic) {
-        hann_window.resize(length);
-        int offset = -1;
-        if (periodic) {
-            offset = 0;
-        }
-        for (int i = 0; i < length; i++) {
-            hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
-        }
+void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
+    hann_window.resize(length);
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
     }
+}
 
-    // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
-    // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
-    void fill_mel_filterbank_matrix(
-        int n_mel,
-        int n_fft,
-        int sample_rate,            // e.g. 16000
-        float fmin = 0.0f,          // e.g. 0.0
-        float fmax = -1.0f,         // e.g. sr/2; pass -1 for auto
-        bool slaney_area_norm = true,
-        float scale = 1.0f          // optional extra scaling; use 1.0f/1000.0f to mimic your code
-    ) {
-        GGML_ASSERT(n_mel > 0 && n_fft > 1);
-        if (fmax <= 0.0f) {
-            fmax = 0.5f * sample_rate;
-        }
+void mtmd_audio_cache::fill_mel_filterbank_matrix(int   n_mel,
+                                                  int   n_fft,
+                                                  int   sample_rate,
+                                                  float fmin,
+                                                  float fmax,
+                                                  bool  slaney_area_norm,
+                                                  float scale) {
+    GGML_ASSERT(n_mel > 0 && n_fft > 1);
+    if (fmax <= 0.0f) {
+        fmax = 0.5f * sample_rate;
+    }
 
-        // Slaney scale (matches librosa default)
-        const double min_log_hz = 1000.0;
-        const double lin_slope = 3 / 200.;
-        const double min_log_mel = min_log_hz * lin_slope;
-        const double log_step = log(6.4) / 27.0;
-        auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
-            return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
-        };
-        auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
-            return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
-        };
-
-        // infer N_fft from n_fft_bins
-        const double bin_hz_step = double(sample_rate) / double(n_fft);
-
-        // mel grid: n_mel + 2 edges
-        const double m_lo = hz_to_mel(fmin);
-        const double m_hi = hz_to_mel(fmax);
-        std::vector<double> mel_pts(n_mel + 2);
-        for (int i = 0; i < n_mel + 2; ++i) {
-            mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
-        }
+    // Slaney scale (matches librosa default)
+    const double min_log_hz  = 1000.0;
+    const double lin_slope   = 3 / 200.;
+    const double min_log_mel = min_log_hz * lin_slope;
+    const double log_step    = log(6.4) / 27.0;
+    auto         hz_to_mel   = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
+        return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
+    };
+    auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
+        return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
+    };
+
+    // infer N_fft from n_fft_bins
+    const double bin_hz_step = double(sample_rate) / double(n_fft);
+
+    // mel grid: n_mel + 2 edges
+    const double        m_lo = hz_to_mel(fmin);
+    const double        m_hi = hz_to_mel(fmax);
+    std::vector<double> mel_pts(n_mel + 2);
+    for (int i = 0; i < n_mel + 2; ++i) {
+        mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
+    }
 
-        // convert to Hz
-        std::vector<double> hz_pts(n_mel + 2);
-        for (int i = 0; i < n_mel + 2; ++i) {
-            hz_pts[i] = mel_to_hz(mel_pts[i]);
-        }
+    // convert to Hz
+    std::vector<double> hz_pts(n_mel + 2);
+    for (int i = 0; i < n_mel + 2; ++i) {
+        hz_pts[i] = mel_to_hz(mel_pts[i]);
+    }
 
-        const int n_fft_bins = n_fft / 2 + 1;
-
-        // filterbank
-        std::vector<float> out(n_mel * n_fft_bins, 0);
-        for (int m = 0; m < n_mel; ++m) {
-            const double f_left   = hz_pts[m];
-            const double f_center = hz_pts[m + 1];
-            const double f_right  = hz_pts[m + 2];
-
-            const double denom_l = std::max(1e-30, f_center - f_left);
-            const double denom_r = std::max(1e-30, f_right  - f_center);
-            const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
-
-            for (int k = 0; k < n_fft_bins; ++k) {
-                const double f = k * bin_hz_step;
-                double w = 0.0;
-                if (f >= f_left && f <= f_center) {
-                    w = (f - f_left) / denom_l;
-                } else if (f > f_center && f <= f_right) {
-                    w = (f_right - f) / denom_r;
-                }
-                out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
+    const int n_fft_bins = n_fft / 2 + 1;
+
+    // filterbank
+    std::vector<float> out(n_mel * n_fft_bins, 0);
+    for (int m = 0; m < n_mel; ++m) {
+        const double f_left   = hz_pts[m];
+        const double f_center = hz_pts[m + 1];
+        const double f_right  = hz_pts[m + 2];
+
+        const double denom_l = std::max(1e-30, f_center - f_left);
+        const double denom_r = std::max(1e-30, f_right - f_center);
+        const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;
+
+        for (int k = 0; k < n_fft_bins; ++k) {
+            const double f = k * bin_hz_step;
+            double       w = 0.0;
+            if (f >= f_left && f <= f_center) {
+                w = (f - f_left) / denom_l;
+            } else if (f > f_center && f <= f_right) {
+                w = (f_right - f) / denom_r;
             }
+            out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
         }
+    }
 
-        filters.n_mel = n_mel;
-        filters.n_fft = n_fft;
-        filters.data  = std::move(out);
+    filters.n_mel = n_mel;
+    filters.n_fft = n_fft;
+    filters.data  = std::move(out);
 
-        if (DEBUG) { // debug
-            for (size_t i = 0; i < filters.data.size(); ++i) {
-                if (filters.data[i] != 0.0f) {
-                    printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
-                }
+    if (DEBUG) {  // debug
+        for (size_t i = 0; i < filters.data.size(); ++i) {
+            if (filters.data[i] != 0.0f) {
+                printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
             }
         }
     }
-} g_cache;
+}
 
-// naive Discrete Fourier Transform
-// input is real-valued
-// output is complex-valued
-static void dft(const float * in, int N, float * out) {
-    const int n_sin_cos_vals = g_cache.sin_vals.size();
-    const int sin_cos_step = n_sin_cos_vals / N;
+// Unified DFT implementation for both forward and inverse transforms
+// Template parameters:
+//   Inverse: false = DFT with exp(-2πi·k·n/N), no scaling
+//            true  = IDFT with exp(+2πi·k·n/N), scales by 1/N
+//   RealInput: true = input is real-valued (stride 1), avoids imaginary computations
+//              false = input is complex-valued (interleaved real/imag, stride 2)
+template <bool Inverse, bool RealInput>
+static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) {
+    const int n_sin_cos_vals = cache.sin_vals.size();
+    const int sin_cos_step   = n_sin_cos_vals / N;
+
+    constexpr float sign  = Inverse ? 1.0f : -1.0f;
+    const float     scale = Inverse ? (1.0f / N) : 1.0f;
 
     for (int k = 0; k < N; k++) {
         float re = 0;
         float im = 0;
 
         for (int n = 0; n < N; n++) {
-            int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
-            re += in[n] * g_cache.cos_vals[idx]; // cos(t)
-            im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
+            int   idx     = (k * n * sin_cos_step) % n_sin_cos_vals;
+            float cos_val = cache.cos_vals[idx];
+            float sin_val = cache.sin_vals[idx];
+
+            if constexpr (RealInput) {
+                // Real input: in_im = 0, simplifies to:
+                // re += in_re * cos_val
+                // im += sign * in_re * sin_val
+                float in_re = in[n];
+                re += in_re * cos_val;
+                im += sign * in_re * sin_val;
+            } else {
+                float in_re = in[n * 2 + 0];
+                float in_im = in[n * 2 + 1];
+                // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i
+                re += in_re * cos_val - sign * in_im * sin_val;
+                im += sign * in_re * sin_val + in_im * cos_val;
+            }
         }
 
-        out[k*2 + 0] = re;
-        out[k*2 + 1] = im;
+        out[k * 2 + 0] = re * scale;
+        out[k * 2 + 1] = im * scale;
     }
 }
 
-// Cooley-Tukey FFT
-// poor man's implementation - use something better
-// input is real-valued
-// output is complex-valued
-static void fft(float * in, int N, float * out) {
-    const int n_sin_cos_vals = g_cache.sin_vals.size();
+// Cooley-Tukey FFT/IFFT unified implementation
+// Template parameters:
+//   Inverse: false = FFT with exp(-2πi·k/N), no scaling
+//            true  = IFFT with exp(+2πi·k/N), scales by 0.5 at each level
+//   RealInput: true = input is real-valued (stride 1)
+//              false = input is complex-valued (interleaved real/imag, stride 2)
+template <bool Inverse, bool RealInput>
+static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    const int n_sin_cos_vals = cache.sin_vals.size();
+
     if (N == 1) {
         out[0] = in[0];
-        out[1] = 0;
+        if constexpr (RealInput) {
+            out[1] = 0.0f;
+        } else {
+            out[1] = in[1];
+        }
         return;
     }
 
     const int half_N = N / 2;
-    if (N - half_N*2 == 1) {
-        dft(in, N, out);
+    if (N - half_N * 2 == 1) {
+        // Odd N: fall back to DFT
+        dft_impl<Inverse, RealInput>(cache, in, N, out);
         return;
     }
 
-    float* even = in + N;
-    for (int i = 0; i < half_N; ++i) {
-        even[i]= in[2*i];
-    }
-    float* even_fft = out + 2 * N;
-    fft(even, half_N, even_fft);
+    // Split into even and odd
+    if constexpr (RealInput) {
+        // Real input: stride is 1, copy only real values
+        float * even = in + N;
+        for (int i = 0; i < half_N; ++i) {
+            even[i] = in[2 * i];
+        }
+        float * even_fft = out + 2 * N;
+        fft_impl<Inverse, RealInput>(cache, even, half_N, even_fft);
+
+        float * odd = even;
+        for (int i = 0; i < half_N; ++i) {
+            odd[i] = in[2 * i + 1];
+        }
+        float * odd_fft = even_fft + N;
+        fft_impl<Inverse, RealInput>(cache, odd, half_N, odd_fft);
+    } else {
+        // Complex input: stride is 2, copy complex pairs
+        float * even = in + N * 2;
+        for (int i = 0; i < half_N; ++i) {
+            even[i * 2 + 0] = in[2 * i * 2 + 0];
+            even[i * 2 + 1] = in[2 * i * 2 + 1];
+        }
+        float * even_fft = out + 2 * N;
+        fft_impl<Inverse, RealInput>(cache, even, half_N, even_fft);
 
-    float* odd = even;
-    for (int i = 0; i < half_N; ++i) {
-        odd[i] = in[2*i + 1];
+        float * odd = even;
+        for (int i = 0; i < half_N; ++i) {
+            odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0];
+            odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1];
+        }
+        float * odd_fft = even_fft + N;
+        fft_impl<Inverse, RealInput>(cache, odd, half_N, odd_fft);
     }
-    float* odd_fft = even_fft + N;
-    fft(odd, half_N, odd_fft);
+
+    float * even_fft = out + 2 * N;
+    float * odd_fft  = even_fft + N;
 
     const int sin_cos_step = n_sin_cos_vals / N;
+
+    constexpr float sign  = Inverse ? 1.0f : -1.0f;
+    constexpr float scale = Inverse ? 0.5f : 1.0f;
+
     for (int k = 0; k < half_N; k++) {
-        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
-        float re =  g_cache.cos_vals[idx]; // cos(t)
-        float im = -g_cache.sin_vals[idx]; // sin(t)
+        int   idx = k * sin_cos_step;  // t = 2*M_PI*k/N
+        float re  = cache.cos_vals[idx];
+        float im  = sign * cache.sin_vals[idx];
 
-        float re_odd = odd_fft[2*k + 0];
-        float im_odd = odd_fft[2*k + 1];
+        float re_odd = odd_fft[2 * k + 0];
+        float im_odd = odd_fft[2 * k + 1];
 
-        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
-        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+        out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd);
+        out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd);
 
-        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
-        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+        out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd);
+        out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd);
     }
 }
 
+// Forward FFT for real input (used by mel spectrogram)
+static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    fft_impl<false, true>(cache, in, N, out);
+}
+
+// Inverse FFT for complex input
+static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) {
+    fft_impl<true, false>(cache, in, N, out);
+}
+
 struct filter_params {
     int32_t n_mel;
     int32_t n_fft_bins;
@@ -222,20 +265,27 @@ struct filter_params {
     bool    norm_per_feature = false;
 };
 
-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
-                                              int n_samples, int frame_size, int frame_step, int n_threads,
-                                              const filter_params & params, mtmd_audio_mel & out) {
+static void log_mel_spectrogram_worker_thread(int                        ith,
+                                              const float *              hann,
+                                              const std::vector<float> & samples,
+                                              int                        n_samples,
+                                              int                        frame_size,
+                                              int                        frame_step,
+                                              int                        n_threads,
+                                              const filter_params &      params,
+                                              const mtmd_audio_cache &   cache,
+                                              mtmd_audio_mel &           out) {
     std::vector<float> fft_in(frame_size * 2, 0.0);
     std::vector<float> fft_out(frame_size * 2 * 2 * 2);
 
     int n_fft_bins = params.n_fft_bins;
     int i = ith;
 
-    const auto & filters = g_cache.filters;
+    const auto & filters = cache.filters;
 
     // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
     GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
-    GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
+    GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size());
     // calculate FFT only when fft_in are not all zero
     for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
         const int offset = i * frame_step;
@@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
         }
 
         // FFT
-        fft(fft_in.data(), frame_size, fft_out.data());
+        fft(cache, fft_in.data(), frame_size, fft_out.data());
 
         // Calculate modulus^2 of complex numbers
         // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
@@ -298,6 +348,7 @@ static bool log_mel_spectrogram(
         const int     n_samples_in,
         const int     n_threads,
         const filter_params & params,
+        const mtmd_audio_cache & cache,
         mtmd_audio_mel & out) {
     //const int64_t t_start_us = ggml_time_us();
 
@@ -305,9 +356,9 @@ static bool log_mel_spectrogram(
     int n_samples = n_samples_in;
 
     // Hann window
-    const float * hann = g_cache.hann_window.data();
-    const int frame_size = (params.n_fft_bins - 1) * 2;
-    const int frame_step = params.hop_length;
+    const float * hann       = cache.hann_window.data();
+    const int     frame_size = (params.n_fft_bins - 1) * 2;
+    const int     frame_step = params.hop_length;
 
     // Padding
     std::vector<float> samples_padded;
@@ -335,9 +386,9 @@ static bool log_mel_spectrogram(
 
     // preemphasis
     if (params.preemph) {
-        const int pad_amount = frame_size / 2;
+        const int   pad_amount = frame_size / 2;
         const float preemph = 0.97f;
-        float prev = samples_padded[pad_amount];
+        float       prev = samples_padded[pad_amount];
         for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
             float cur = samples_padded[i];
             samples_padded[i] = cur - preemph * prev;
@@ -372,14 +423,14 @@ static bool log_mel_spectrogram(
     {
         std::vector<std::thread> workers(n_threads - 1);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
-            workers[iw] = std::thread(
-                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
-                    n_samples, frame_size, frame_step, n_threads,
-                    std::cref(params), std::ref(out));
+            workers[iw] =
+                std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples,
+                            frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out));
         }
 
         // main thread
-        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params,
+                                          cache, out);
         for (int iw = 0; iw < n_threads - 1; ++iw) {
             workers[iw].join();
         }
@@ -404,7 +455,7 @@ static bool log_mel_spectrogram(
 
             for (int j = 0; j < effective_n_len; ++j) {
                 auto &value = out.data[i * out.n_len + j];
-                value = (value - mean) / mstd;
+                value        = (value - mean) / mstd;
             }
 
             // pad the rest with zeros
@@ -450,18 +501,14 @@ static bool log_mel_spectrogram(
 //
 
 void mtmd_audio_preprocessor_whisper::initialize() {
-    g_cache.fill_sin_cos_table(hparams.audio_n_fft);
-    g_cache.fill_hann_window(hparams.audio_window_len, true);
-    g_cache.fill_mel_filterbank_matrix(
-        hparams.n_mel_bins,
-        hparams.audio_n_fft,
-        hparams.audio_sample_rate);
+    cache.fill_sin_cos_table(hparams.audio_n_fft);
+    cache.fill_hann_window(hparams.audio_window_len, true);
+    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
 }
 
-bool mtmd_audio_preprocessor_whisper::preprocess(
-        const float * samples,
-        size_t n_samples,
-        std::vector<mtmd_audio_mel> & output) {
+bool mtmd_audio_preprocessor_whisper::preprocess(const float *                 samples,
+                                                 size_t                        n_samples,
+                                                 std::vector<mtmd_audio_mel> & output) {
     if (n_samples == 0) {
         // empty audio
         return false;
@@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
     // if input is too short, pad with zeros
     // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
     // TODO: maybe handle this better
-    size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
+    size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1);  // +1 second margin
     if (n_samples < min_samples) {
         smpl.resize(min_samples, 0.0f);
         std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
@@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
     params.hop_length       = hparams.audio_hop_len;
     params.sample_rate      = hparams.audio_sample_rate;
     params.center_padding   = false;
-    params.preemph          = 0.0f; // disabled
+    params.preemph          = 0.0f;  // disabled
     params.use_natural_log  = false;
     params.norm_per_feature = false;
 
-    // make sure the global cache is initialized
-    GGML_ASSERT(!g_cache.sin_vals.empty());
-    GGML_ASSERT(!g_cache.cos_vals.empty());
-    GGML_ASSERT(!g_cache.filters.data.empty());
+    // make sure the cache is initialized
+    GGML_ASSERT(!cache.sin_vals.empty());
+    GGML_ASSERT(!cache.cos_vals.empty());
+    GGML_ASSERT(!cache.filters.data.empty());
 
     mtmd_audio_mel out_full;
-    bool ok = log_mel_spectrogram(
-                samples,
-                n_samples,
-                4, // n_threads
-                params,
-                out_full);
+    bool           ok = log_mel_spectrogram(samples, n_samples,
+                                            4,  // n_threads
+                                            params, cache, out_full);
     if (!ok) {
         return false;
     }
@@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
         printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
     }
     const size_t frames_per_chunk = 3000;
-    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
-    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
-        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
-        if ((size_t)n_len < frames_per_chunk) {
-            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
+    GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk);
+    for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) {
+        int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off);
+        if ((size_t) n_len < frames_per_chunk) {
+            break;  // last uncomplete chunk will always be a padded chunk, safe to ignore
         }
 
         mtmd_audio_mel out_chunk;
         out_chunk.n_len     = n_len;
         out_chunk.n_mel     = out_full.n_mel;
-        out_chunk.n_len_org = out_full.n_mel; // unused
+        out_chunk.n_len_org = out_full.n_mel;  // unused
         out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
 
         for (int i = 0; i < out_full.n_mel; i++) {
-            auto src = out_full.data.begin() + i*out_full.n_len + off;
+            auto src = out_full.data.begin() + i * out_full.n_len + off;
             out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
         }
 
@@ -535,3 +579,152 @@ bool mtmd_audio_preprocessor_whisper::preprocess(
 
     return true;
 }
+
+//
+// mtmd_audio_preprocessor_conformer
+//
+
+void mtmd_audio_preprocessor_conformer::initialize() {
+    cache.fill_sin_cos_table(hparams.audio_n_fft);
+    cache.fill_hann_window(hparams.audio_window_len, true);
+    cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate);
+}
+
+bool mtmd_audio_preprocessor_conformer::preprocess(const float *                 samples,
+                                                   size_t                        n_samples,
+                                                   std::vector<mtmd_audio_mel> & output) {
+    // empty audio
+    if (n_samples == 0) {
+        return false;
+    }
+
+    filter_params params;
+    params.n_mel            = hparams.n_mel_bins;
+    params.n_fft_bins       = 1 + (hparams.audio_n_fft / 2);
+    params.hann_window_size = hparams.audio_window_len;
+    params.hop_length       = hparams.audio_hop_len;
+    params.sample_rate      = hparams.audio_sample_rate;
+    params.center_padding   = true;
+    params.preemph          = 0.97f;
+    params.use_natural_log  = true;
+    params.norm_per_feature = true;
+
+    // make sure the cache is initialized
+    GGML_ASSERT(!cache.sin_vals.empty());
+    GGML_ASSERT(!cache.cos_vals.empty());
+    GGML_ASSERT(!cache.filters.data.empty());
+
+    mtmd_audio_mel out_full;
+    bool           ok = log_mel_spectrogram(samples, n_samples,
+                                            4,  // n_threads
+                                            params, cache, out_full);
+    if (!ok) {
+        return false;
+    }
+
+    output.push_back(std::move(out_full));
+    return true;
+}
+
+//
+// mtmd_audio_streaming_istft implementation
+//
+
+mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) :
+    n_fft(n_fft),
+    hop_length(hop_length),
+    n_fft_bins(n_fft / 2 + 1),
+    overlap_buffer(n_fft, 0.0f),
+    window_sum_buffer(n_fft, 0.0f),
+    padding_to_remove((n_fft - hop_length) / 2),
+    ifft_in(n_fft * 2 * 4, 0.0f),  // extra space for recursive IFFT
+    ifft_out(n_fft * 2 * 4, 0.0f) {
+    cache.fill_sin_cos_table(n_fft);
+    cache.fill_hann_window(n_fft, true);
+}
+
+void mtmd_audio_streaming_istft::reset() {
+    std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f);
+    std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f);
+    padding_to_remove = (n_fft - hop_length) / 2;
+}
+
+std::vector<float> mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) {
+    std::vector<float> output(hop_length);
+
+    // copy frequencies
+    for (int j = 0; j < n_fft_bins; j++) {
+        ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0];
+        ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1];
+    }
+
+    // mirror negative frequencies
+    for (int j = 1; j < n_fft_bins - 1; j++) {
+        int mirror_idx              = n_fft - j;
+        ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0];
+        ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1];  // conjugate
+    }
+
+    ifft(cache, ifft_in.data(), n_fft, ifft_out.data());
+
+    // update window sum and overlap buffer
+    for (int j = 0; j < n_fft; j++) {
+        window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j];
+        overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j];
+    }
+
+    // extract hop_length samples with normalization
+    for (int i = 0; i < hop_length; i++) {
+        if (window_sum_buffer[i] > 1e-8f) {
+            output[i] = overlap_buffer[i] / window_sum_buffer[i];
+        } else {
+            output[i] = overlap_buffer[i];
+        }
+    }
+
+    // shift buffers left by hop_length
+    std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin());
+    std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f);
+
+    std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin());
+    std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f);
+
+    // Remove padding if needed
+    int to_remove = std::min(padding_to_remove, (int) output.size());
+    padding_to_remove -= to_remove;
+    output.erase(output.begin(), output.begin() + to_remove);
+
+    return output;
+}
+
+std::vector<float> mtmd_audio_streaming_istft::flush() {
+    std::vector<float> output;
+
+    // Extract remaining samples from overlap buffer
+    // Continue until we've extracted all meaningful samples
+    int remaining = n_fft - hop_length;
+    while (remaining > 0) {
+        int chunk_size = std::min(remaining, hop_length);
+
+        for (int i = 0; i < chunk_size; i++) {
+            float sample;
+            if (window_sum_buffer[i] > 1e-8f) {
+                sample = overlap_buffer[i] / window_sum_buffer[i];
+            } else {
+                sample = overlap_buffer[i];
+            }
+            output.push_back(sample);
+        }
+
+        // Shift buffers
+        std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin());
+        std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f);
+
+        std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin());
+        std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f);
+
+        remaining -= chunk_size;
+    }
+
+    return output;
+}
diff --git a/llama/llama.cpp/tools/mtmd/mtmd-audio.h b/llama/llama.cpp/tools/mtmd/mtmd-audio.h
index 1b454337cbe..016c7392e4f 100644
--- a/llama/llama.cpp/tools/mtmd/mtmd-audio.h
+++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.h
@@ -17,6 +17,38 @@ struct mtmd_audio_mel {
     std::vector<float> data;
 };
 
+struct mtmd_audio_mel_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector<float> data;
+};
+
+// cache for audio processing, each processor instance owns its own cache
+struct mtmd_audio_cache {
+    std::vector<float> sin_vals;
+    std::vector<float> cos_vals;
+
+    std::vector<float> hann_window;
+
+    mtmd_audio_mel_filters filters;
+
+    void fill_sin_cos_table(int n);
+
+    void fill_hann_window(int length, bool periodic);
+
+    // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
+    // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
+    void fill_mel_filterbank_matrix(int   n_mel,
+                                    int   n_fft,
+                                    int   sample_rate,               // e.g. 16000
+                                    float fmin             = 0.0f,   // e.g. 0.0
+                                    float fmax             = -1.0f,  // e.g. sr/2; pass -1 for auto
+                                    bool  slaney_area_norm = true,
+                                    float scale = 1.0f  // optional extra scaling
+    );
+};
+
 struct mtmd_audio_preprocessor {
     const clip_hparams & hparams;
 
@@ -31,4 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor {
     mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
     void initialize() override;
     bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
+
+  private:
+    mtmd_audio_cache cache;
+};
+
+struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor {
+    mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {}
+    void initialize() override;
+    bool preprocess(const float * samples, size_t n_samples, std::vector<mtmd_audio_mel> & output) override;
+
+  private:
+    mtmd_audio_cache cache;
+};
+
+//
+// streaming ISTFT - converts spectrogram frames back to audio one frame at a time
+//
+struct mtmd_audio_streaming_istft {
+    mtmd_audio_streaming_istft(int n_fft, int hop_length);
+
+    // reset streaming state
+    void reset();
+
+    // process a single STFT frame (streaming)
+    // frame_spectrum: [n_fft_bins x 2] interleaved real/imag
+    // returns: up to hop_length samples
+    std::vector<float> process_frame(const float * frame_spectrum);
+
+    // flush remaining samples at end of stream
+    std::vector<float> flush();
+
+  private:
+    int n_fft;
+    int hop_length;
+    int n_fft_bins;
+
+    // Own cache for output processing
+    mtmd_audio_cache cache;
+
+    // Streaming state
+    std::vector<float> overlap_buffer;
+    std::vector<float> window_sum_buffer;
+    int                padding_to_remove;
+
+    // Working buffers for IFFT
+    std::vector<float> ifft_in;
+    std::vector<float> ifft_out;
 };
diff --git a/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp b/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp
index 902a4b456d9..c75f90730f1 100644
--- a/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp
+++ b/llama/llama.cpp/tools/mtmd/mtmd-helper.cpp
@@ -248,7 +248,7 @@ int32_t mtmd_helper_decode_image_chunk(
 
     int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
     int32_t i_batch = 0;
-    int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
+    int32_t n_img_batches = (n_tokens + n_batch - 1) / n_batch;
     decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
 
     if (mtmd_decode_use_mrope(ctx)) {
diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp
index c4e905a4e9c..03fcb32e74c 100644
--- a/llama/llama.cpp/tools/mtmd/mtmd.cpp
+++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp
@@ -85,6 +85,7 @@ enum mtmd_slice_tmpl {
     MTMD_SLICE_TMPL_MINICPMV_2_6,
     MTMD_SLICE_TMPL_LLAMA4,
     MTMD_SLICE_TMPL_IDEFICS3,
+    MTMD_SLICE_TMPL_LFM2,
 };
 
 mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) {
@@ -121,6 +122,8 @@ mtmd_context_params mtmd_context_params_default() {
         /* warmup            */ true,
         /* image_min_tokens  */ -1,
         /* image_max_tokens  */ -1,
+        /* cb_eval           */ nullptr,
+        /* cb_eval_user_data */ nullptr,
     };
     return params;
 }
@@ -156,8 +159,6 @@ struct mtmd_context {
     bool        tok_row_end_trail = false;
     bool        ov_img_first      = false;
 
-    bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
-
     // string template for slice image delimiters with row/col (idefics3)
     std::string sli_img_start_tmpl;
 
@@ -184,10 +185,12 @@ struct mtmd_context {
 
         clip_context_params ctx_clip_params {
             /* use_gpu           */ ctx_params.use_gpu,
-            /* flash_attn_type   */ CLIP_FLASH_ATTN_TYPE_AUTO,
+            /* flash_attn_type   */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
             /* image_min_tokens  */ ctx_params.image_min_tokens,
             /* image_max_tokens  */ ctx_params.image_max_tokens,
             /* warmup            */ ctx_params.warmup,
+            /* cb_eval           */ ctx_params.cb_eval,
+            /* cb_eval_user_data */ ctx_params.cb_eval_user_data,
         };
 
         auto res = clip_init(mmproj_fname, ctx_clip_params);
@@ -227,7 +230,6 @@ struct mtmd_context {
 
     void init_vision() {
         GGML_ASSERT(ctx_v != nullptr);
-        use_mrope = clip_is_mrope(ctx_v);
 
         projector_type proj = clip_get_projector_type(ctx_v);
         int minicpmv_version = clip_is_minicpmv(ctx_v);
@@ -245,7 +247,7 @@ struct mtmd_context {
             tok_row_end_trail = false; // no trailing end-of-row token
             ov_img_first      = true;
 
-        } else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6) {
+        } else if (minicpmv_version == 3 || minicpmv_version == 4 || minicpmv_version == 5 || minicpmv_version == 6 || minicpmv_version == 100045) {
             // minicpmv 2.6 format:
             //  (overview)  (slice)  (slice) \n ...
             slice_tmpl        = MTMD_SLICE_TMPL_MINICPMV_2_6;
@@ -276,7 +278,7 @@ struct mtmd_context {
         }
 
         // set boi/eoi
-        if (proj == PROJECTOR_TYPE_GEMMA3) {
+        if (proj == PROJECTOR_TYPE_GEMMA3 || proj == PROJECTOR_TYPE_GEMMA3NV) {
             //  ... (image embeddings) ... 
             img_beg = "";
             img_end = "";
@@ -293,7 +295,7 @@ struct mtmd_context {
             // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
             img_end = "[IMG_END]";
 
-        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) {
+        } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) {
             // <|vision_start|> ... (image embeddings) ... <|vision_end|>
             img_beg = "<|vision_start|>";
             img_end = "<|vision_end|>";
@@ -316,13 +318,27 @@ struct mtmd_context {
             img_end = "<|im_end|>";
 
         } else if (proj == PROJECTOR_TYPE_LFM2) {
-            img_beg = "<|image_start|>";
-            img_end = "<|image_end|>";
-
+            // multi-tile:
+            //   <|image_start|>
+            //     <|img_row_1_col_1|> (tile) <|img_row_1_col_2|> (tile) ...
+            //     <|img_thumbnail|> (thumbnail)
+            //   <|image_end|>
+            // single-tile:
+            //   <|image_start|> (image) <|image_end|>
+            img_beg            = "<|image_start|>";
+            img_end            = "<|image_end|>";
+            slice_tmpl         = MTMD_SLICE_TMPL_LFM2;
+            sli_img_start_tmpl = "<|img_row_%d_col_%d|>";
+            tok_ov_img_start   = {lookup_token("<|img_thumbnail|>")};
+            ov_img_first       = false;
         } else if (proj == PROJECTOR_TYPE_GLM4V) {
             img_beg = "<|begin_of_image|>";
             img_end = "<|end_of_image|>";
 
+        } else if (proj == PROJECTOR_TYPE_PADDLEOCR) {
+            // <|IMAGE_START|> ... (image embeddings) ... <|IMAGE_END|>
+            img_beg = "<|IMAGE_START|>";
+            img_end = "<|IMAGE_END|>";
         }
     }
 
@@ -339,8 +355,13 @@ struct mtmd_context {
             case PROJECTOR_TYPE_QWEN25O:
             case PROJECTOR_TYPE_ULTRAVOX:
             case PROJECTOR_TYPE_VOXTRAL:
+            case PROJECTOR_TYPE_GLMA:
+            case PROJECTOR_TYPE_MUSIC_FLAMINGO:
                 audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
                 break;
+            case PROJECTOR_TYPE_LFM2A:
+                audio_preproc = std::make_unique<mtmd_audio_preprocessor_conformer>(ctx_a);
+                break;
             default:
                 GGML_ABORT("unsupported audio projector type");
         }
@@ -358,6 +379,9 @@ struct mtmd_context {
             // [BEGIN_AUDIO] ... (embeddings) ...
             aud_beg = "[BEGIN_AUDIO]";
 
+        } else if (proj == PROJECTOR_TYPE_MUSIC_FLAMINGO) {
+            //  ... (embeddings) ...
+            aud_beg = "";
         }
     }
 
@@ -563,11 +587,13 @@ struct mtmd_tokenizer {
             }
 
             // handle llava-uhd style preprocessing
+            const bool has_tiling_grid = batch_f32.grid_x > 0 && batch_f32.grid_y > 0;
             if (
                 ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
                 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
                 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
                 || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3
+                || (ctx->slice_tmpl == MTMD_SLICE_TMPL_LFM2 && has_tiling_grid)
             ) {
                 const int n_col = batch_f32.grid_x;
                 const int n_row = batch_f32.grid_y;
@@ -629,7 +655,7 @@ struct mtmd_tokenizer {
                 }
 
                 mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-                if (ctx->use_mrope) {
+                if (mtmd_decode_use_mrope(ctx)) {
                     // for Qwen2VL, we need this information for M-RoPE decoding positions
                     image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get());
                     image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get());
@@ -864,14 +890,25 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
 }
 
 bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
-    if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) {
-        return true;
+    switch (ctx->proj_type_v()) {
+        case PROJECTOR_TYPE_GEMMA3:
+            return true;
+        default:
+            return false;
     }
-    return false;
 }
 
 bool mtmd_decode_use_mrope(mtmd_context * ctx) {
-    return ctx->use_mrope;
+    switch (ctx->proj_type_v()) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_GLM4V:
+        case PROJECTOR_TYPE_PADDLEOCR:
+            return true;
+        default:
+            return false;
+    }
 }
 
 bool mtmd_support_vision(mtmd_context * ctx) {
diff --git a/llama/llama.cpp/tools/mtmd/mtmd.h b/llama/llama.cpp/tools/mtmd/mtmd.h
index 72cec193774..a4a45b29955 100644
--- a/llama/llama.cpp/tools/mtmd/mtmd.h
+++ b/llama/llama.cpp/tools/mtmd/mtmd.h
@@ -27,6 +27,9 @@
  * - Make sure the C API is aligned with the libllama C API (as in llama.h)
  * - Do not include model name (e.g., qwen, gemma) in the API, use generic terms instead
  * - Keep the API minimal, do not expose internal details unless necessary
+ *
+ * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
+ * We encourage human contributors to ensure the quality and reliability of the codebase.
  */
 
 #ifdef LLAMA_SHARED
@@ -95,6 +98,10 @@ struct mtmd_context_params {
     // limit number of image tokens, only for vision models with dynamic resolution
     int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
     int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
+
+    // callback function passed over to mtmd proper
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 MTMD_API const char * mtmd_default_marker(void);
@@ -220,7 +227,7 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
 
 // get output embeddings from the last encode pass
 // the reading size (in bytes) is equal to:
-// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
+// llama_model_n_embd_inp(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
 // Set callback for all future logging events.
@@ -273,12 +280,12 @@ struct bitmap {
         ptr.reset(mtmd_bitmap_init(nx, ny, data));
     }
     ~bitmap() = default;
-    uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
-    uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
-    const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
-    size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
-    std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
-    void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
+    uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); }
+    uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); }
+    const unsigned char * data() const { return mtmd_bitmap_get_data(ptr.get()); }
+    size_t n_bytes() const { return mtmd_bitmap_get_n_bytes(ptr.get()); }
+    std::string id() const { return mtmd_bitmap_get_id(ptr.get()); }
+    void set_id(const char * id) const { mtmd_bitmap_set_id(ptr.get(), id); }
 };
 
 struct bitmaps {
@@ -302,8 +309,8 @@ struct input_chunks {
     input_chunks() = default;
     input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {}
     ~input_chunks() = default;
-    size_t size() { return mtmd_input_chunks_size(ptr.get()); }
-    const mtmd_input_chunk * operator[](size_t idx) {
+    size_t size() const { return mtmd_input_chunks_size(ptr.get()); }
+    const mtmd_input_chunk * operator[](size_t idx) const {
         return mtmd_input_chunks_get(ptr.get(), idx);
     }
 };
diff --git a/llama/llama.cpp/vendor/miniaudio/miniaudio.h b/llama/llama.cpp/vendor/miniaudio/miniaudio.h
index 2f5b9c4eaf3..24e676bb264 100644
--- a/llama/llama.cpp/vendor/miniaudio/miniaudio.h
+++ b/llama/llama.cpp/vendor/miniaudio/miniaudio.h
@@ -1,6 +1,6 @@
 /*
 Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file.
-miniaudio - v0.11.24 - TBD
+miniaudio - v0.11.24 - 2026-01-17
 
 David Reid - mackron@gmail.com
 
@@ -3858,7 +3858,7 @@ typedef ma_uint16 wchar_t;
 
 
 /* Platform/backend detection. */
-#if defined(_WIN32) || defined(__COSMOPOLITAN__)
+#if defined(_WIN32)
     #define MA_WIN32
     #if defined(MA_FORCE_UWP) || (defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PC_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PC_APP) || (defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP)))
         #define MA_WIN32_UWP
@@ -4182,9 +4182,13 @@ typedef enum
     MA_CHANNEL_AUX_29             = 49,
     MA_CHANNEL_AUX_30             = 50,
     MA_CHANNEL_AUX_31             = 51,
+
+    /* Count. */
+    MA_CHANNEL_POSITION_COUNT,
+
+    /* Aliases. */
     MA_CHANNEL_LEFT               = MA_CHANNEL_FRONT_LEFT,
     MA_CHANNEL_RIGHT              = MA_CHANNEL_FRONT_RIGHT,
-    MA_CHANNEL_POSITION_COUNT     = (MA_CHANNEL_AUX_31 + 1)
 } _ma_channel_position; /* Do not use `_ma_channel_position` directly. Use `ma_channel` instead. */
 
 typedef enum
@@ -6604,16 +6608,12 @@ This section contains the APIs for device playback and capture. Here is where yo
     #if defined(MA_WIN32_DESKTOP)   /* DirectSound and WinMM backends are only supported on desktops. */
         #define MA_SUPPORT_DSOUND
         #define MA_SUPPORT_WINMM
-
-        /* Don't enable JACK here if compiling with Cosmopolitan. It'll be enabled in the Linux section below. */
-        #if !defined(__COSMOPOLITAN__)
-            #define MA_SUPPORT_JACK    /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */
-        #endif
+        #define MA_SUPPORT_JACK     /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */
     #endif
 #endif
 #if defined(MA_UNIX) && !defined(MA_ORBIS) && !defined(MA_PROSPERO)
     #if defined(MA_LINUX)
-        #if !defined(MA_ANDROID) && !defined(__COSMOPOLITAN__)   /* ALSA is not supported on Android. */
+        #if !defined(MA_ANDROID) && !defined(MA_EMSCRIPTEN)   /* ALSA is not supported on Android. */
             #define MA_SUPPORT_ALSA
         #endif
     #endif
@@ -10520,6 +10520,7 @@ typedef struct
     ma_decoding_backend_vtable** ppCustomDecodingBackendVTables;
     ma_uint32 customDecodingBackendCount;
     void* pCustomDecodingBackendUserData;
+    ma_resampler_config resampling;
 } ma_resource_manager_config;
 
 MA_API ma_resource_manager_config ma_resource_manager_config_init(void);
@@ -10847,6 +10848,7 @@ MA_API ma_result ma_node_graph_read_pcm_frames(ma_node_graph* pNodeGraph, void*
 MA_API ma_uint32 ma_node_graph_get_channels(const ma_node_graph* pNodeGraph);
 MA_API ma_uint64 ma_node_graph_get_time(const ma_node_graph* pNodeGraph);
 MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 globalTime);
+MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph);
 
 
 
@@ -11154,6 +11156,7 @@ typedef struct
     ma_bool8 isPitchDisabled;           /* Pitching can be explicitly disabled with MA_SOUND_FLAG_NO_PITCH to optimize processing. */
     ma_bool8 isSpatializationDisabled;  /* Spatialization can be explicitly disabled with MA_SOUND_FLAG_NO_SPATIALIZATION. */
     ma_uint8 pinnedListenerIndex;       /* The index of the listener this node should always use for spatialization. If set to MA_LISTENER_INDEX_CLOSEST the engine will use the closest listener. */
+    ma_resampler_config resampling;
 } ma_engine_node_config;
 
 MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_engine_node_type type, ma_uint32 flags);
@@ -11168,7 +11171,7 @@ typedef struct
     ma_uint32 volumeSmoothTimeInPCMFrames;
     ma_mono_expansion_mode monoExpansionMode;
     ma_fader fader;
-    ma_linear_resampler resampler;                      /* For pitch shift. */
+    ma_resampler resampler;                             /* For pitch shift. */
     ma_spatializer spatializer;
     ma_panner panner;
     ma_gainer volumeGainer;                             /* This will only be used if volumeSmoothTimeInPCMFrames is > 0. */
@@ -11224,6 +11227,7 @@ typedef struct
     ma_uint64 loopPointEndInPCMFrames;
     ma_sound_end_proc endCallback;              /* Fired when the sound reaches the end. Will be fired from the audio thread. Do not restart, uninitialize or otherwise change the state of the sound from here. Instead fire an event or set a variable to indicate to a different thread to change the start of the sound. Will not be fired in response to a scheduled stop with ma_sound_set_stop_time_*(). */
     void* pEndCallbackUserData;
+    ma_resampler_config pitchResampling;
 #ifndef MA_NO_RESOURCE_MANAGER
     ma_resource_manager_pipeline_notifications initNotifications;
 #endif
@@ -11242,7 +11246,10 @@ struct ma_sound
     MA_ATOMIC(4, ma_bool32) atEnd;
     ma_sound_end_proc endCallback;
     void* pEndCallbackUserData;
-    ma_bool8 ownsDataSource;
+    float* pProcessingCache;            /* Will be null if pDataSource is null. */
+    ma_uint32 processingCacheFramesRemaining;
+    ma_uint32 processingCacheCap;
+    ma_bool8 ownsDataSource;    
 
     /*
     We're declaring a resource manager data source object here to save us a malloc when loading a
@@ -11300,6 +11307,8 @@ typedef struct
     ma_vfs* pResourceManagerVFS;                    /* A pointer to a pre-allocated VFS object to use with the resource manager. This is ignored if pResourceManager is not NULL. */
     ma_engine_process_proc onProcess;               /* Fired at the end of each call to ma_engine_read_pcm_frames(). For engine's that manage their own internal device (the default configuration), this will be fired from the audio thread, and you do not need to call ma_engine_read_pcm_frames() manually in order to trigger this. */
     void* pProcessUserData;                         /* User data that's passed into onProcess. */
+    ma_resampler_config resourceManagerResampling;  /* The resampling config to use with the resource manager. */
+    ma_resampler_config pitchResampling;            /* The resampling config for the pitch and Doppler effects. You will typically want this to be a fast resampler. For high quality stuff, it's recommended that you pre-resample. */
 } ma_engine_config;
 
 MA_API ma_engine_config ma_engine_config_init(void);
@@ -11329,6 +11338,7 @@ struct ma_engine
     ma_mono_expansion_mode monoExpansionMode;
     ma_engine_process_proc onProcess;
     void* pProcessUserData;
+    ma_resampler_config pitchResamplingConfig;
 };
 
 MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEngine);
@@ -11389,8 +11399,12 @@ MA_API ma_engine* ma_sound_get_engine(const ma_sound* pSound);
 MA_API ma_data_source* ma_sound_get_data_source(const ma_sound* pSound);
 MA_API ma_result ma_sound_start(ma_sound* pSound);
 MA_API ma_result ma_sound_stop(ma_sound* pSound);
-MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames);     /* Will overwrite any scheduled stop and fade. */
-MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames);   /* Will overwrite any scheduled stop and fade. */
+MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames);     /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */
+MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames);   /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */
+MA_API void ma_sound_reset_start_time(ma_sound* pSound);
+MA_API void ma_sound_reset_stop_time(ma_sound* pSound);
+MA_API void ma_sound_reset_fade(ma_sound* pSound);
+MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound);  /* Resets fades and scheduled stop time. Does not seek back to the start. */
 MA_API void ma_sound_set_volume(ma_sound* pSound, float volume);
 MA_API float ma_sound_get_volume(const ma_sound* pSound);
 MA_API void ma_sound_set_pan(ma_sound* pSound, float pan);
@@ -11643,7 +11657,7 @@ IMPLEMENTATION
 #endif
 
 /* Intrinsics Support */
-#if (defined(MA_X64) || defined(MA_X86)) && !defined(__COSMOPOLITAN__)
+#if defined(MA_X64) || defined(MA_X86)
     #if defined(_MSC_VER) && !defined(__clang__)
         /* MSVC. */
         #if _MSC_VER >= 1400 && !defined(MA_NO_SSE2)   /* 2005 */
@@ -12080,7 +12094,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void)
     }
     #elif defined(MA_X86) || defined(MA_X64)
     {
-        #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */
+        #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */
         {
             prevState = _mm_getcsr();
             _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK);
@@ -12120,7 +12134,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState)
     }
     #elif defined(MA_X86) || defined(MA_X64)
     {
-        #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__))   /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */
+        #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__))   /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */
         {
             _mm_setcsr(prevState);
         }
@@ -17616,7 +17630,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority
             int priorityStep = (priorityMax - priorityMin) / 7;  /* 7 = number of priorities supported by miniaudio. */
 
             struct sched_param sched;
-            if (pthread_attr_getschedparam(&attr, &sched) == 0) {
+            if (priorityMin != -1 && priorityMax != -1 && pthread_attr_getschedparam(&attr, &sched) == 0) {
                 if (priority == ma_thread_priority_idle) {
                     sched.sched_priority = priorityMin;
                 } else if (priority == ma_thread_priority_realtime) {
@@ -20073,7 +20087,7 @@ Timing
             struct timespec newTime;
             clock_gettime(MA_CLOCK_ID, &newTime);
 
-            pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec;
+            pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000000) + newTime.tv_nsec;
         }
 
         static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer)
@@ -20084,7 +20098,7 @@ Timing
             struct timespec newTime;
             clock_gettime(MA_CLOCK_ID, &newTime);
 
-            newTimeCounter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec;
+            newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000000) + newTime.tv_nsec;
             oldTimeCounter = pTimer->counter;
 
             return (newTimeCounter - oldTimeCounter) / 1000000000.0;
@@ -20095,7 +20109,7 @@ Timing
             struct timeval newTime;
             gettimeofday(&newTime, NULL);
 
-            pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec;
+            pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000) + newTime.tv_usec;
         }
 
         static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer)
@@ -20106,7 +20120,7 @@ Timing
             struct timeval newTime;
             gettimeofday(&newTime, NULL);
 
-            newTimeCounter = (newTime.tv_sec * 1000000) + newTime.tv_usec;
+            newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000) + newTime.tv_usec;
             oldTimeCounter = pTimer->counter;
 
             return (newTimeCounter - oldTimeCounter) / 1000000.0;
@@ -31228,6 +31242,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext,
     result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL));
     if (result != MA_SUCCESS) {
         ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context.");
+        ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext));
         ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop));
         return result;
     }
@@ -31236,6 +31251,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext,
     result = ma_wait_for_pa_context_to_connect__pulse(pContext, pMainLoop, pPulseContext);
     if (result != MA_SUCCESS) {
         ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Waiting for connection failed.");
+        ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext));
         ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop));
         return result;
     }
@@ -41747,8 +41763,11 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const
         frameCount = pDevice->capture.internalPeriodSizeInFrames;
     }
 
+    /*
+    If this is called but the device has not yet been started we need to return early, making sure we output silence to
+    the output buffer.
+    */
     if (ma_device_get_state(pDevice) != ma_device_state_started) {
-        /* Fill the output buffer with zero to avoid a noise sound */
         for (int i = 0; i < outputCount; i += 1) {
             MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float));
         }
@@ -41770,7 +41789,9 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const
     if (outputCount > 0) {
         /* If it's a capture-only device, we'll need to output silence. */
         if (pDevice->type == ma_device_type_capture) {
-            MA_ZERO_MEMORY(pOutputs[0].data, frameCount * pDevice->playback.internalChannels * sizeof(float));
+            for (int i = 0; i < outputCount; i += 1) {
+                MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float));
+            }
         } else {
             ma_device_process_pcm_frames_playback__webaudio(pDevice, frameCount, pDevice->webaudio.pIntermediaryBuffer);
 
@@ -41780,6 +41801,14 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const
                     pOutputs[0].data[frameCount*iChannel + iFrame] = pDevice->webaudio.pIntermediaryBuffer[iFrame*pDevice->playback.internalChannels + iChannel];
                 }
             }
+
+            /*
+            Just above we output data to the first output buffer. Here we just make sure we're putting silence into any
+            remaining output buffers.
+            */
+            for (int i = 1; i < outputCount; i += 1) {  /* <-- Note that the counter starts at 1 instead of 0. */
+                MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float));
+            }
         }
     }
 
@@ -50850,15 +50879,15 @@ static /*__attribute__((noinline))*/ ma_result ma_gainer_process_pcm_frames_inte
                     a += d;
                 }
             }
+
+            pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float));
+            pFramesIn  = ma_offset_ptr(pFramesIn,  interpolatedFrameCount * sizeof(float));
         }
 
+        frameCount -= interpolatedFrameCount;
+
         /* Make sure the timer is updated. */
         pGainer->t = (ma_uint32)ma_min(pGainer->t + interpolatedFrameCount, pGainer->config.smoothTimeInFrames);
-
-        /* Adjust our arguments so the next part can work normally. */
-        frameCount -= interpolatedFrameCount;
-        pFramesOut  = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float));
-        pFramesIn   = ma_offset_ptr(pFramesIn,  interpolatedFrameCount * sizeof(float));
     }
 
     /* All we need to do here is apply the new gains using an optimized path. */
@@ -52286,13 +52315,16 @@ static float ma_calculate_angular_gain(ma_vec3f dirA, ma_vec3f dirB, float coneI
 
 MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, ma_spatializer_listener* pListener, void* pFramesOut, const void* pFramesIn, ma_uint64 frameCount)
 {
-    ma_channel* pChannelMapIn  = pSpatializer->pChannelMapIn;
-    ma_channel* pChannelMapOut = pListener->config.pChannelMapOut;
+    ma_channel* pChannelMapIn;
+    ma_channel* pChannelMapOut;
 
-    if (pSpatializer == NULL) {
+    if (pSpatializer == NULL || pListener == NULL) {
         return MA_INVALID_ARGS;
     }
 
+    pChannelMapIn = pSpatializer->pChannelMapIn;
+    pChannelMapOut = pListener->config.pChannelMapOut;
+
     /* If we're not spatializing we need to run an optimized path. */
     if (ma_atomic_load_i32(&pSpatializer->attenuationModel) == ma_attenuation_model_none) {
         if (ma_spatializer_listener_is_enabled(pListener)) {
@@ -52337,23 +52369,17 @@ MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer,
         We'll need the listener velocity for doppler pitch calculations. The speed of sound is
         defined by the listener, so we'll grab that here too.
         */
-        if (pListener != NULL) {
-            listenerVel  = ma_spatializer_listener_get_velocity(pListener);
-            speedOfSound = pListener->config.speedOfSound;
-        } else {
-            listenerVel  = ma_vec3f_init_3f(0, 0, 0);
-            speedOfSound = MA_DEFAULT_SPEED_OF_SOUND;
-        }
+        listenerVel  = ma_spatializer_listener_get_velocity(pListener);
+        speedOfSound = pListener->config.speedOfSound;
 
-        if (pListener == NULL || ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) {
-            /* There's no listener or we're using relative positioning. */
+        if (ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) {
             relativePos = ma_spatializer_get_position(pSpatializer);
             relativeDir = ma_spatializer_get_direction(pSpatializer);
         } else {
             /*
-            We've found a listener and we're using absolute positioning. We need to transform the
-            sound's position and direction so that it's relative to listener. Later on we'll use
-            this for determining the factors to apply to each channel to apply the panning effect.
+            We're using absolute positioning. We need to transform the sound's position and
+            direction so that it's relative to listener. Later on we'll use this for determining
+            the factors to apply to each channel to apply the panning effect.
             */
             ma_spatializer_get_relative_position_and_direction(pSpatializer, pListener, &relativePos, &relativeDir);
         }
@@ -54388,7 +54414,7 @@ static ma_bool32 ma_is_spatial_channel_position(ma_channel channelPosition)
         return MA_FALSE;
     }
 
-    if (channelPosition >= MA_CHANNEL_AUX_0 && channelPosition <= MA_CHANNEL_AUX_31) {
+    if (channelPosition >= MA_CHANNEL_AUX_0) {
         return MA_FALSE;
     }
 
@@ -61676,7 +61702,6 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf
 
     if (result == MA_NOT_IMPLEMENTED) {
         /* Not implemented. Fall back to seek/tell/seek. */
-        ma_result result;
         ma_int64 cursor;
         ma_int64 sizeInBytes;
         
@@ -61884,6 +61909,8 @@ Decoding and Encoding Headers. These are auto-generated from a tool.
 
 **************************************************************************************************************************************************************/
 #if !defined(MA_NO_WAV) && (!defined(MA_NO_DECODING) || !defined(MA_NO_ENCODING))
+#define MA_HAS_WAV
+
 /* dr_wav_h begin */
 #ifndef ma_dr_wav_h
 #define ma_dr_wav_h
@@ -61894,7 +61921,7 @@ extern "C" {
 #define MA_DR_WAV_XSTRINGIFY(x)     MA_DR_WAV_STRINGIFY(x)
 #define MA_DR_WAV_VERSION_MAJOR     0
 #define MA_DR_WAV_VERSION_MINOR     14
-#define MA_DR_WAV_VERSION_REVISION  1
+#define MA_DR_WAV_VERSION_REVISION  4
 #define MA_DR_WAV_VERSION_STRING    MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION)
 #include 
 #define MA_DR_WAVE_FORMAT_PCM          0x1
@@ -62317,6 +62344,8 @@ MA_API ma_bool32 ma_dr_wav_fourcc_equal(const ma_uint8* a, const char* b);
 #endif  /* MA_NO_WAV */
 
 #if !defined(MA_NO_FLAC) && !defined(MA_NO_DECODING)
+#define MA_HAS_FLAC
+
 /* dr_flac_h begin */
 #ifndef ma_dr_flac_h
 #define ma_dr_flac_h
@@ -62327,7 +62356,7 @@ extern "C" {
 #define MA_DR_FLAC_XSTRINGIFY(x)     MA_DR_FLAC_STRINGIFY(x)
 #define MA_DR_FLAC_VERSION_MAJOR     0
 #define MA_DR_FLAC_VERSION_MINOR     13
-#define MA_DR_FLAC_VERSION_REVISION  2
+#define MA_DR_FLAC_VERSION_REVISION  3
 #define MA_DR_FLAC_VERSION_STRING    MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION)
 #include 
 #if defined(_MSC_VER) && _MSC_VER >= 1700
@@ -62609,6 +62638,8 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat
 #endif  /* MA_NO_FLAC */
 
 #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING)
+#define MA_HAS_MP3
+
 #ifndef MA_DR_MP3_NO_SIMD
     #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64)))
     #define MA_DR_MP3_NO_SIMD
@@ -62625,7 +62656,7 @@ extern "C" {
 #define MA_DR_MP3_XSTRINGIFY(x)     MA_DR_MP3_STRINGIFY(x)
 #define MA_DR_MP3_VERSION_MAJOR     0
 #define MA_DR_MP3_VERSION_MINOR     7
-#define MA_DR_MP3_VERSION_REVISION  2
+#define MA_DR_MP3_VERSION_REVISION  3
 #define MA_DR_MP3_VERSION_STRING    MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION)
 #include 
 #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME  1152
@@ -63229,7 +63260,6 @@ static ma_result ma_decoder_init_custom_from_memory__internal(const void* pData,
 
 /* WAV */
 #ifdef ma_dr_wav_h
-#define MA_HAS_WAV
 
 typedef struct
 {
@@ -63935,7 +63965,6 @@ static ma_result ma_decoder_init_wav_from_memory__internal(const void* pData, si
 
 /* FLAC */
 #ifdef ma_dr_flac_h
-#define MA_HAS_FLAC
 
 typedef struct
 {
@@ -64579,7 +64608,6 @@ static ma_result ma_decoder_init_flac_from_memory__internal(const void* pData, s
 
 /* MP3 */
 #ifdef ma_dr_mp3_h
-#define MA_HAS_MP3
 
 typedef struct
 {
@@ -66257,11 +66285,9 @@ static ma_result ma_decoder_init__internal(ma_decoder_read_proc onRead, ma_decod
         We use trial and error to open a decoder. We prioritize custom decoders so that if they
         implement the same encoding format they take priority over the built-in decoders.
         */
+        result = ma_decoder_init_custom__internal(pConfig, pDecoder);
         if (result != MA_SUCCESS) {
-            result = ma_decoder_init_custom__internal(pConfig, pDecoder);
-            if (result != MA_SUCCESS) {
-                onSeek(pDecoder, 0, ma_seek_origin_start);
-            }
+            onSeek(pDecoder, 0, ma_seek_origin_start);
         }
 
         /*
@@ -66525,14 +66551,6 @@ MA_API ma_result ma_decoder_init_memory(const void* pData, size_t dataSize, cons
         /* Initialization was successful. Finish up. */
         result = ma_decoder__postinit(&config, pDecoder);
         if (result != MA_SUCCESS) {
-            /*
-            The backend was initialized successfully, but for some reason post-initialization failed. This is most likely
-            due to an out of memory error. We're going to abort with an error here and not try to recover.
-            */
-            if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) {
-                pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks);
-            }
-
             return result;
         }
     } else {
@@ -66833,11 +66851,9 @@ MA_API ma_result ma_decoder_init_vfs(ma_vfs* pVFS, const char* pFilePath, const
         We use trial and error to open a decoder. We prioritize custom decoders so that if they
         implement the same encoding format they take priority over the built-in decoders.
         */
+        result = ma_decoder_init_custom__internal(&config, pDecoder);
         if (result != MA_SUCCESS) {
-            result = ma_decoder_init_custom__internal(&config, pDecoder);
-            if (result != MA_SUCCESS) {
-                ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start);
-            }
+            ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start);
         }
 
         /*
@@ -66966,11 +66982,9 @@ MA_API ma_result ma_decoder_init_vfs_w(ma_vfs* pVFS, const wchar_t* pFilePath, c
         We use trial and error to open a decoder. We prioritize custom decoders so that if they
         implement the same encoding format they take priority over the built-in decoders.
         */
+        result = ma_decoder_init_custom__internal(&config, pDecoder);
         if (result != MA_SUCCESS) {
-            result = ma_decoder_init_custom__internal(&config, pDecoder);
-            if (result != MA_SUCCESS) {
-                ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start);
-            }
+            ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start);
         }
 
         /*
@@ -67152,14 +67166,6 @@ MA_API ma_result ma_decoder_init_file(const char* pFilePath, const ma_decoder_co
         /* Initialization was successful. Finish up. */
         result = ma_decoder__postinit(&config, pDecoder);
         if (result != MA_SUCCESS) {
-            /*
-            The backend was initialized successfully, but for some reason post-initialization failed. This is most likely
-            due to an out of memory error. We're going to abort with an error here and not try to recover.
-            */
-            if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) {
-                pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks);
-            }
-
             return result;
         }
     } else {
@@ -67302,14 +67308,6 @@ MA_API ma_result ma_decoder_init_file_w(const wchar_t* pFilePath, const ma_decod
         /* Initialization was successful. Finish up. */
         result = ma_decoder__postinit(&config, pDecoder);
         if (result != MA_SUCCESS) {
-            /*
-            The backend was initialized successfully, but for some reason post-initialization failed. This is most likely
-            due to an out of memory error. We're going to abort with an error here and not try to recover.
-            */
-            if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) {
-                pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks);
-            }
-
             return result;
         }
     } else {
@@ -69955,6 +69953,7 @@ MA_API ma_resource_manager_config ma_resource_manager_config_init(void)
     config.decodedSampleRate = 0;
     config.jobThreadCount    = 1;   /* A single miniaudio-managed job thread by default. */
     config.jobQueueCapacity  = MA_JOB_TYPE_RESOURCE_MANAGER_QUEUE_CAPACITY;
+    config.resampling        = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); /* Format/channels/rate doesn't matter here. */
 
     /* Flags. */
     config.flags = 0;
@@ -70208,6 +70207,7 @@ static ma_decoder_config ma_resource_manager__init_decoder_config(ma_resource_ma
     config.ppCustomBackendVTables = pResourceManager->config.ppCustomDecodingBackendVTables;
     config.customBackendCount     = pResourceManager->config.customDecodingBackendCount;
     config.pCustomBackendUserData = pResourceManager->config.pCustomDecodingBackendUserData;
+    config.resampling = pResourceManager->config.resampling;
 
     return config;
 }
@@ -71533,13 +71533,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_data_format(ma_resource_man
 
 MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pCursor)
 {
-    /* We cannot be using the data source after it's been uninitialized. */
-    MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE);
-
     if (pDataBuffer == NULL || pCursor == NULL) {
         return MA_INVALID_ARGS;
     }
 
+    /* We cannot be using the data source after it's been uninitialized. */
+    MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE);
+
     *pCursor = 0;
 
     switch (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode))
@@ -71573,13 +71573,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_res
 
 MA_API ma_result ma_resource_manager_data_buffer_get_length_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pLength)
 {
-    /* We cannot be using the data source after it's been uninitialized. */
-    MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE);
-
     if (pDataBuffer == NULL || pLength == NULL) {
         return MA_INVALID_ARGS;
     }
 
+    /* We cannot be using the data source after it's been uninitialized. */
+    MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE);
+
     if (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode) == ma_resource_manager_data_supply_type_unknown) {
         return MA_BUSY; /* Still loading. */
     }
@@ -72934,8 +72934,6 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job*
         return ma_resource_manager_post_job(pResourceManager, pJob);    /* Out of order. */
     }
 
-    ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode);
-
     /* The event needs to be signalled last. */
     if (pJob->data.resourceManager.freeDataBufferNode.pDoneNotification != NULL) {
         ma_async_notification_signal(pJob->data.resourceManager.freeDataBufferNode.pDoneNotification);
@@ -72946,6 +72944,9 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job*
     }
 
     ma_atomic_fetch_add_32(&pDataBufferNode->executionPointer, 1);
+
+    ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode);
+
     return MA_SUCCESS;
 }
 
@@ -73818,6 +73819,15 @@ MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 glo
     return ma_node_set_time(&pNodeGraph->endpoint, globalTime); /* Global time is just the local time of the endpoint. */
 }
 
+MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph)
+{
+    if (pNodeGraph == NULL) {
+        return 0;
+    }
+
+    return pNodeGraph->processingSizeInFrames;
+}
+
 
 #define MA_NODE_OUTPUT_BUS_FLAG_HAS_READ    0x01    /* Whether or not this bus ready to read more data. Only used on nodes with multiple output buses. */
 
@@ -74977,12 +74987,12 @@ MA_API ma_node_state ma_node_get_state_by_time_range(const ma_node* pNode, ma_ui
     its start time not having been reached yet. Also, the stop time may have also been reached in
     which case it'll be considered stopped.
     */
-    if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeBeg) {
-        return ma_node_state_stopped;   /* Start time has not yet been reached. */
+    if (ma_node_get_state_time(pNode, ma_node_state_stopped) < globalTimeBeg) {
+        return ma_node_state_stopped;   /* End time is before the start of the range. */
     }
 
-    if (ma_node_get_state_time(pNode, ma_node_state_stopped) <= globalTimeEnd) {
-        return ma_node_state_stopped;   /* Stop time has been reached. */
+    if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeEnd) {
+        return ma_node_state_stopped;   /* Start time is after the end of the range. */
     }
 
     /* Getting here means the node is marked as started and is within its start/stop times. */
@@ -75062,14 +75072,14 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde
         return MA_INVALID_ARGS; /* Invalid output bus index. */
     }
 
+    globalTimeBeg = globalTime;
+    globalTimeEnd = globalTime + frameCount;
+
     /* Don't do anything if we're in a stopped state. */
-    if (ma_node_get_state_by_time_range(pNode, globalTime, globalTime + frameCount) != ma_node_state_started) {
+    if (ma_node_get_state_by_time_range(pNode, globalTimeBeg, globalTimeEnd) != ma_node_state_started) {
         return MA_SUCCESS;  /* We're in a stopped state. This is not an error - we just need to not read anything. */
     }
 
-
-    globalTimeBeg = globalTime;
-    globalTimeEnd = globalTime + frameCount;
     startTime = ma_node_get_state_time(pNode, ma_node_state_started);
     stopTime  = ma_node_get_state_time(pNode, ma_node_state_stopped);
 
@@ -75082,11 +75092,16 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde
     therefore need to offset it by a number of frames to accommodate. The same thing applies for
     the stop time.
     */
-    timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(globalTimeEnd - startTime) : 0;
+    timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(startTime - globalTimeBeg) : 0;
     timeOffsetEnd = (globalTimeEnd > stopTime)  ? (ma_uint32)(globalTimeEnd - stopTime)  : 0;
 
     /* Trim based on the start offset. We need to silence the start of the buffer. */
     if (timeOffsetBeg > 0) {
+        MA_ASSERT(timeOffsetBeg <= frameCount);
+        if (timeOffsetBeg > frameCount) {
+            timeOffsetBeg = frameCount;
+        }
+
         ma_silence_pcm_frames(pFramesOut, timeOffsetBeg, ma_format_f32, ma_node_get_output_channels(pNode, outputBusIndex));
         pFramesOut += timeOffsetBeg * ma_node_get_output_channels(pNode, outputBusIndex);
         frameCount -= timeOffsetBeg;
@@ -75094,6 +75109,11 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde
 
     /* Trim based on the end offset. We don't need to silence the tail section because we'll just have a reduced value written to pFramesRead. */
     if (timeOffsetEnd > 0) {
+        MA_ASSERT(timeOffsetEnd <= frameCount);
+        if (timeOffsetEnd > frameCount) {
+            timeOffsetEnd = frameCount;
+        }
+
         frameCount -= timeOffsetEnd;
     }
 
@@ -76508,12 +76528,20 @@ static void ma_sound_set_at_end(ma_sound* pSound, ma_bool32 atEnd)
     MA_ASSERT(pSound != NULL);
     ma_atomic_exchange_32(&pSound->atEnd, atEnd);
 
+    /*
+    When this function is called the state of the sound will not yet be in a stopped state. This makes it confusing
+    because an end callback will intuitively expect ma_sound_is_playing() to return false from inside the callback.
+    I'm therefore no longer firing the callback here and will instead fire it manually in the *next* processing step
+    when the state should be set to stopped as expected.
+    */
+    #if 0
     /* Fire any callbacks or events. */
     if (atEnd) {
         if (pSound->endCallback != NULL) {
             pSound->endCallback(pSound->pEndCallbackUserData, pSound);
         }
     }
+    #endif
 }
 
 static ma_bool32 ma_sound_get_at_end(const ma_sound* pSound)
@@ -76533,6 +76561,7 @@ MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_e
     config.isPitchDisabled          = (flags & MA_SOUND_FLAG_NO_PITCH) != 0;
     config.isSpatializationDisabled = (flags & MA_SOUND_FLAG_NO_SPATIALIZATION) != 0;
     config.monoExpansionMode        = pEngine->monoExpansionMode;
+    config.resampling               = pEngine->pitchResamplingConfig;
 
     return config;
 }
@@ -76559,7 +76588,7 @@ static void ma_engine_node_update_pitch_if_required(ma_engine_node* pEngineNode)
 
     if (isUpdateRequired) {
         float basePitch = (float)pEngineNode->sampleRate / ma_engine_get_sample_rate(pEngineNode->pEngine);
-        ma_linear_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch);
+        ma_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch);
     }
 }
 
@@ -76578,22 +76607,6 @@ static ma_bool32 ma_engine_node_is_spatialization_enabled(const ma_engine_node*
     return !ma_atomic_load_explicit_32(&pEngineNode->isSpatializationDisabled, ma_atomic_memory_order_acquire);
 }
 
-static ma_uint64 ma_engine_node_get_required_input_frame_count(const ma_engine_node* pEngineNode, ma_uint64 outputFrameCount)
-{
-    ma_uint64 inputFrameCount = 0;
-
-    if (ma_engine_node_is_pitching_enabled(pEngineNode)) {
-        ma_result result = ma_linear_resampler_get_required_input_frame_count(&pEngineNode->resampler, outputFrameCount, &inputFrameCount);
-        if (result != MA_SUCCESS) {
-            inputFrameCount = 0;
-        }
-    } else {
-        inputFrameCount = outputFrameCount;    /* No resampling, so 1:1. */
-    }
-
-    return inputFrameCount;
-}
-
 static ma_result ma_engine_node_set_volume(ma_engine_node* pEngineNode, float volume)
 {
     if (pEngineNode == NULL) {
@@ -76735,7 +76748,7 @@ static void ma_engine_node_process_pcm_frames__general(ma_engine_node* pEngineNo
             ma_uint64 resampleFrameCountIn  = framesAvailableIn;
             ma_uint64 resampleFrameCountOut = framesAvailableOut;
 
-            ma_linear_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut);
+            ma_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut);
             isWorkingBufferValid = MA_TRUE;
 
             framesJustProcessedIn  = (ma_uint32)resampleFrameCountIn;
@@ -76859,6 +76872,11 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float
     /* If we're marked at the end we need to stop the sound and do nothing. */
     if (ma_sound_at_end(pSound)) {
         ma_sound_stop(pSound);
+
+        if (pSound->endCallback != NULL) {
+            pSound->endCallback(pSound->pEndCallbackUserData, pSound);
+        }
+
         *pFrameCountOut = 0;
         return;
     }
@@ -76896,55 +76914,74 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float
         /* Keep reading until we've read as much as was requested or we reach the end of the data source. */
         while (totalFramesRead < frameCount) {
             ma_uint32 framesRemaining = frameCount - totalFramesRead;
-            ma_uint32 framesToRead;
             ma_uint64 framesJustRead;
             ma_uint32 frameCountIn;
             ma_uint32 frameCountOut;
             const float* pRunningFramesIn;
             float* pRunningFramesOut;
 
-            /*
-            The first thing we need to do is read into the temporary buffer. We can calculate exactly
-            how many input frames we'll need after resampling.
-            */
-            framesToRead = (ma_uint32)ma_engine_node_get_required_input_frame_count(&pSound->engineNode, framesRemaining);
-            if (framesToRead > tempCapInFrames) {
-                framesToRead = tempCapInFrames;
-            }
+            /* If there's any input frames sitting in the cache get those processed first. */
+            if (pSound->processingCacheFramesRemaining > 0) {
+                pRunningFramesIn = pSound->pProcessingCache;
+                frameCountIn = pSound->processingCacheFramesRemaining;
 
-            result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToRead, &framesJustRead);
+                pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0));
+                frameCountOut = framesRemaining;
 
-            /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */
-            if (result == MA_AT_END) {
-                ma_sound_set_at_end(pSound, MA_TRUE);   /* This will be set to false in ma_sound_start(). */
-            }
+                ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut);
 
-            pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0));
+                MA_ASSERT(frameCountIn <= pSound->processingCacheFramesRemaining);
+                pSound->processingCacheFramesRemaining -= frameCountIn;
 
-            frameCountIn = (ma_uint32)framesJustRead;
-            frameCountOut = framesRemaining;
+                /* Move any remaining data in the cache down. */
+                if (pSound->processingCacheFramesRemaining > 0) {
+                    MA_MOVE_MEMORY(pSound->pProcessingCache, ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, frameCountIn, dataSourceChannels), pSound->processingCacheFramesRemaining * ma_get_bytes_per_frame(ma_format_f32, dataSourceChannels));
+                }
+                
+                totalFramesRead += (ma_uint32)frameCountOut;   /* Safe cast. */
 
-            /* Convert if necessary. */
-            if (dataSourceFormat == ma_format_f32) {
-                /* Fast path. No data conversion necessary. */
-                pRunningFramesIn = (float*)temp;
-                ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut);
+                if (result != MA_SUCCESS || ma_sound_at_end(pSound)) {
+                    break;  /* Might have reached the end. */
+                }
             } else {
-                /* Slow path. Need to do sample format conversion to f32. If we give the f32 buffer the same count as the first temp buffer, we're guaranteed it'll be large enough. */
-                float tempf32[MA_DATA_CONVERTER_STACK_BUFFER_SIZE]; /* Do not do `MA_DATA_CONVERTER_STACK_BUFFER_SIZE/sizeof(float)` here like we've done in other places. */
-                ma_convert_pcm_frames_format(tempf32, ma_format_f32, temp, dataSourceFormat, framesJustRead, dataSourceChannels, ma_dither_mode_none);
+                /* Getting here means there's nothing in the cache. Read more data from the data source. */
+                if (dataSourceFormat == ma_format_f32) {
+                    /* Fast path. No conversion to f32 necessary. */
+                    result = ma_data_source_read_pcm_frames(pSound->pDataSource, pSound->pProcessingCache, pSound->processingCacheCap, &framesJustRead);
+                } else {
+                    /* Slow path. Need to convert to f32. */
+                    ma_uint64 totalFramesConverted = 0;
+
+                    while (totalFramesConverted < pSound->processingCacheCap) {
+                        ma_uint64 framesConverted;
+                        ma_uint32 framesToConvertThisIteration = pSound->processingCacheCap - (ma_uint32)totalFramesConverted;
+                        if (framesToConvertThisIteration > tempCapInFrames) {
+                            framesToConvertThisIteration = tempCapInFrames;
+                        }
 
-                /* Now that we have our samples in f32 format we can process like normal. */
-                pRunningFramesIn = tempf32;
-                ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut);
-            }
+                        result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToConvertThisIteration, &framesConverted);
+                        if (result != MA_SUCCESS) {
+                            break;
+                        }
+
+                        ma_convert_pcm_frames_format(ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, totalFramesConverted, dataSourceChannels), ma_format_f32, temp, dataSourceFormat, framesConverted, dataSourceChannels, ma_dither_mode_none);
+                        totalFramesConverted += framesConverted;
+                    }
+
+                    framesJustRead = totalFramesConverted;
+                }
+
+                MA_ASSERT(framesJustRead <= pSound->processingCacheCap);
+                pSound->processingCacheFramesRemaining = (ma_uint32)framesJustRead;
 
-            /* We should have processed all of our input frames since we calculated the required number of input frames at the top. */
-            MA_ASSERT(frameCountIn == framesJustRead);
-            totalFramesRead += (ma_uint32)frameCountOut;   /* Safe cast. */
+                /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */
+                if (result == MA_AT_END) {
+                    ma_sound_set_at_end(pSound, MA_TRUE);   /* This will be set to false in ma_sound_start(). */
+                }
 
-            if (result != MA_SUCCESS || ma_sound_at_end(pSound)) {
-                break;  /* Might have reached the end. */
+                if (result != MA_SUCCESS || ma_sound_at_end(pSound)) {
+                    break;
+                }
             }
         }
     }
@@ -76967,25 +77004,6 @@ static void ma_engine_node_process_pcm_frames__group(ma_node* pNode, const float
     ma_engine_node_process_pcm_frames__general((ma_engine_node*)pNode, ppFramesIn, pFrameCountIn, ppFramesOut, pFrameCountOut);
 }
 
-static ma_result ma_engine_node_get_required_input_frame_count__group(ma_node* pNode, ma_uint32 outputFrameCount, ma_uint32* pInputFrameCount)
-{
-    ma_uint64 inputFrameCount;
-
-    MA_ASSERT(pInputFrameCount != NULL);
-
-    /* Our pitch will affect this calculation. We need to update it. */
-    ma_engine_node_update_pitch_if_required((ma_engine_node*)pNode);
-
-    inputFrameCount = ma_engine_node_get_required_input_frame_count((ma_engine_node*)pNode, outputFrameCount);
-    if (inputFrameCount > 0xFFFFFFFF) {
-        inputFrameCount = 0xFFFFFFFF;    /* Will never happen because miniaudio will only ever process in relatively small chunks. */
-    }
-
-    *pInputFrameCount = (ma_uint32)inputFrameCount;
-
-    return MA_SUCCESS;
-}
-
 
 static ma_node_vtable g_ma_engine_node_vtable__sound =
 {
@@ -76999,7 +77017,7 @@ static ma_node_vtable g_ma_engine_node_vtable__sound =
 static ma_node_vtable g_ma_engine_node_vtable__group =
 {
     ma_engine_node_process_pcm_frames__group,
-    ma_engine_node_get_required_input_frame_count__group,
+    NULL,   /* onGetRequiredInputFrameCount */
     1,      /* Groups have one input bus. */
     1,      /* Groups have one output bus. */
     MA_NODE_FLAG_DIFFERENT_PROCESSING_RATES /* The engine node does resampling so should let miniaudio know about it. */
@@ -77045,9 +77063,10 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo
     ma_result result;
     size_t tempHeapSize;
     ma_node_config baseNodeConfig;
-    ma_linear_resampler_config resamplerConfig;
+    ma_resampler_config resamplerConfig;
     ma_spatializer_config spatializerConfig;
     ma_gainer_config gainerConfig;
+    ma_uint32 sampleRate;
     ma_uint32 channelsIn;
     ma_uint32 channelsOut;
     ma_channel defaultStereoChannelMap[2] = {MA_CHANNEL_SIDE_LEFT, MA_CHANNEL_SIDE_RIGHT};  /* <-- Consistent with the default channel map of a stereo listener. Means channel conversion can run on a fast path. */
@@ -77066,6 +77085,7 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo
 
     pHeapLayout->sizeInBytes = 0;
 
+    sampleRate  = (pConfig->sampleRate   > 0) ? pConfig->sampleRate  : ma_engine_get_sample_rate(pConfig->pEngine);
     channelsIn  = (pConfig->channelsIn  != 0) ? pConfig->channelsIn  : ma_engine_get_channels(pConfig->pEngine);
     channelsOut = (pConfig->channelsOut != 0) ? pConfig->channelsOut : ma_engine_get_channels(pConfig->pEngine);
 
@@ -77085,10 +77105,13 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo
 
 
     /* Resmapler. */
-    resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, channelsIn, 1, 1); /* Input and output sample rates don't affect the calculation of the heap size. */
-    resamplerConfig.lpfOrder = 0;
+    resamplerConfig = pConfig->resampling;
+    resamplerConfig.format        = ma_format_f32;
+    resamplerConfig.channels      = channelsIn;
+    resamplerConfig.sampleRateIn  = sampleRate;
+    resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pConfig->pEngine);
 
-    result = ma_linear_resampler_get_heap_size(&resamplerConfig, &tempHeapSize);
+    result = ma_resampler_get_heap_size(&resamplerConfig, &tempHeapSize);
     if (result != MA_SUCCESS) {
         return result;  /* Failed to retrieve the size of the heap for the resampler. */
     }
@@ -77156,7 +77179,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p
     ma_result result;
     ma_engine_node_heap_layout heapLayout;
     ma_node_config baseNodeConfig;
-    ma_linear_resampler_config resamplerConfig;
+    ma_resampler_config resamplerConfig;
     ma_fader_config faderConfig;
     ma_spatializer_config spatializerConfig;
     ma_panner_config pannerConfig;
@@ -77231,10 +77254,13 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p
     */
 
     /* We'll always do resampling first. */
-    resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, baseNodeConfig.pInputChannels[0], pEngineNode->sampleRate, ma_engine_get_sample_rate(pEngineNode->pEngine));
-    resamplerConfig.lpfOrder = 0;    /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */
+    resamplerConfig = pConfig->resampling;
+    resamplerConfig.format        = ma_format_f32;
+    resamplerConfig.channels      = baseNodeConfig.pInputChannels[0];
+    resamplerConfig.sampleRateIn  = pEngineNode->sampleRate;
+    resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pEngineNode->pEngine);
 
-    result = ma_linear_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler);
+    result = ma_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler);
     if (result != MA_SUCCESS) {
         goto error1;
     }
@@ -77293,7 +77319,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p
 
     /* No need for allocation callbacks here because we use a preallocated heap. */
 error3: ma_spatializer_uninit(&pEngineNode->spatializer, NULL);
-error2: ma_linear_resampler_uninit(&pEngineNode->resampler, NULL);
+error2: ma_resampler_uninit(&pEngineNode->resampler, NULL);
 error1: ma_node_uninit(&pEngineNode->baseNode, NULL);
 error0: return result;
 }
@@ -77342,7 +77368,7 @@ MA_API void ma_engine_node_uninit(ma_engine_node* pEngineNode, const ma_allocati
     }
 
     ma_spatializer_uninit(&pEngineNode->spatializer, pAllocationCallbacks);
-    ma_linear_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks);
+    ma_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks);
 
     /* Free the heap last. */
     if (pEngineNode->_ownsHeap) {
@@ -77364,8 +77390,12 @@ MA_API ma_sound_config ma_sound_config_init_2(ma_engine* pEngine)
 
     if (pEngine != NULL) {
         config.monoExpansionMode = pEngine->monoExpansionMode;
+        config.pitchResampling = pEngine->pitchResamplingConfig;
     } else {
         config.monoExpansionMode = ma_mono_expansion_mode_default;
+
+        config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear);
+        config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */
     }
 
     config.rangeEndInPCMFrames     = ~((ma_uint64)0);
@@ -77387,8 +77417,12 @@ MA_API ma_sound_group_config ma_sound_group_config_init_2(ma_engine* pEngine)
 
     if (pEngine != NULL) {
         config.monoExpansionMode = pEngine->monoExpansionMode;
+        config.pitchResampling   = pEngine->pitchResamplingConfig;
     } else {
         config.monoExpansionMode = ma_mono_expansion_mode_default;
+
+        config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear);
+        config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */
     }
 
     return config;
@@ -77400,8 +77434,12 @@ MA_API ma_engine_config ma_engine_config_init(void)
     ma_engine_config config;
 
     MA_ZERO_OBJECT(&config);
-    config.listenerCount     = 1;   /* Always want at least one listener. */
-    config.monoExpansionMode = ma_mono_expansion_mode_default;
+    config.listenerCount             = 1;   /* Always want at least one listener. */
+    config.monoExpansionMode         = ma_mono_expansion_mode_default;
+    config.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear);
+
+    config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear);
+    config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */
 
     return config;
 }
@@ -77482,6 +77520,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng
     pEngine->defaultVolumeSmoothTimeInPCMFrames = engineConfig.defaultVolumeSmoothTimeInPCMFrames;
     pEngine->onProcess = engineConfig.onProcess;
     pEngine->pProcessUserData = engineConfig.pProcessUserData;
+    pEngine->pitchResamplingConfig = engineConfig.pitchResampling;
     ma_allocation_callbacks_init_copy(&pEngine->allocationCallbacks, &engineConfig.allocationCallbacks);
 
     #if !defined(MA_NO_RESOURCE_MANAGER)
@@ -77664,6 +77703,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng
             resourceManagerConfig.decodedSampleRate = ma_engine_get_sample_rate(pEngine);
             ma_allocation_callbacks_init_copy(&resourceManagerConfig.allocationCallbacks, &pEngine->allocationCallbacks);
             resourceManagerConfig.pVFS              = engineConfig.pResourceManagerVFS;
+            resourceManagerConfig.resampling        = engineConfig.resourceManagerResampling;
 
             /* The Emscripten build cannot use threads unless it's targeting pthreads. */
             #if defined(MA_EMSCRIPTEN) && !defined(__EMSCRIPTEN_PTHREADS__)
@@ -78389,6 +78429,25 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con
     }
 
 
+    /*
+    When pulling data from a data source we need a processing cache to hold onto unprocessed input data from the data source
+    after doing resampling.
+    */
+    if (pSound->pDataSource != NULL) {
+        pSound->processingCacheFramesRemaining = 0;
+        pSound->processingCacheCap = ma_node_graph_get_processing_size_in_frames(&pEngine->nodeGraph);
+        if (pSound->processingCacheCap == 0) {
+            pSound->processingCacheCap = 512;
+        }
+        
+        pSound->pProcessingCache = (float*)ma_calloc(pSound->processingCacheCap * ma_get_bytes_per_frame(ma_format_f32, engineNodeConfig.channelsIn), &pEngine->allocationCallbacks);
+        if (pSound->pProcessingCache == NULL) {
+            ma_engine_node_uninit(&pSound->engineNode, &pEngine->allocationCallbacks);
+            return MA_OUT_OF_MEMORY;
+        }
+    }
+
+
     /* Apply initial range and looping state to the data source if applicable. */
     if (pConfig->rangeBegInPCMFrames != 0 || pConfig->rangeEndInPCMFrames != ~((ma_uint64)0)) {
         ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->rangeBegInPCMFrames, pConfig->rangeEndInPCMFrames);
@@ -78626,6 +78685,11 @@ MA_API void ma_sound_uninit(ma_sound* pSound)
     */
     ma_engine_node_uninit(&pSound->engineNode, &pSound->engineNode.pEngine->allocationCallbacks);
 
+    if (pSound->pProcessingCache != NULL) {
+        ma_free(pSound->pProcessingCache, &pSound->engineNode.pEngine->allocationCallbacks);
+        pSound->pProcessingCache = NULL;
+    }
+
     /* Once the sound is detached from the group we can guarantee that it won't be referenced by the mixer thread which means it's safe for us to destroy the data source. */
 #ifndef MA_NO_RESOURCE_MANAGER
     if (pSound->ownsDataSource) {
@@ -78721,6 +78785,27 @@ MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_ui
     return ma_sound_stop_with_fade_in_pcm_frames(pSound, (fadeLengthInMilliseconds * sampleRate) / 1000);
 }
 
+MA_API void ma_sound_reset_start_time(ma_sound* pSound)
+{
+    ma_sound_set_start_time_in_pcm_frames(pSound, 0);
+}
+
+MA_API void ma_sound_reset_stop_time(ma_sound* pSound)
+{
+    ma_sound_set_stop_time_in_pcm_frames(pSound, ~(ma_uint64)0);
+}
+
+MA_API void ma_sound_reset_fade(ma_sound* pSound)
+{
+    ma_sound_set_fade_in_pcm_frames(pSound, 0, 1, 0);
+}
+
+MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound)
+{
+    ma_sound_reset_stop_time(pSound);
+    ma_sound_reset_fade(pSound);
+}
+
 MA_API void ma_sound_set_volume(ma_sound* pSound, float volume)
 {
     if (pSound == NULL) {
@@ -79372,7 +79457,7 @@ MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFo
         }
 
         if (pSampleRate != NULL) {
-            *pSampleRate = pSound->engineNode.resampler.config.sampleRateIn;
+            *pSampleRate = pSound->engineNode.resampler.sampleRateIn;
         }
 
         if (pChannelMap != NULL) {
@@ -82436,7 +82521,6 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_d
     ma_dr_wav* pWav = (ma_dr_wav*)pUserData;
     ma_int64 newCursor;
     MA_DR_WAV_ASSERT(pWav != NULL);
-    newCursor = pWav->memoryStream.currentReadPos;
     if (origin == MA_DR_WAV_SEEK_SET) {
         newCursor = 0;
     } else if (origin == MA_DR_WAV_SEEK_CUR) {
@@ -82490,7 +82574,6 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset
     ma_dr_wav* pWav = (ma_dr_wav*)pUserData;
     ma_int64 newCursor;
     MA_DR_WAV_ASSERT(pWav != NULL);
-    newCursor = pWav->memoryStreamWrite.currentWritePos;
     if (origin == MA_DR_WAV_SEEK_SET) {
         newCursor = 0;
     } else if (origin == MA_DR_WAV_SEEK_CUR) {
@@ -82499,7 +82582,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset
         newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize;
     } else {
         MA_DR_WAV_ASSERT(!"Invalid seek origin");
-        return MA_INVALID_ARGS;
+        return MA_FALSE;
     }
     newCursor += offset;
     if (newCursor < 0) {
@@ -83000,7 +83083,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
                 pWav->msadpcm.cachedFrames[2]  = pWav->msadpcm.prevFrames[0][0];
                 pWav->msadpcm.cachedFrames[3]  = pWav->msadpcm.prevFrames[0][1];
                 pWav->msadpcm.cachedFrameCount = 2;
-                if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table)) {
+                if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) {
                     return totalFramesRead;
                 }
             } else {
@@ -83022,7 +83105,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
                 pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1];
                 pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1];
                 pWav->msadpcm.cachedFrameCount = 2;
-                if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) {
+                if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table) ||
+                    pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) {
                     return totalFramesRead;
                 }
             }
@@ -83059,6 +83143,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
                 if (pWav->channels == 1) {
                     ma_int32 newSample0;
                     ma_int32 newSample1;
+                    if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) {
+                        return totalFramesRead;
+                    }
                     newSample0  = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
                     newSample0 += nibble0 * pWav->msadpcm.delta[0];
                     newSample0  = ma_dr_wav_clamp(newSample0, -32768, 32767);
@@ -83083,6 +83170,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
                 } else {
                     ma_int32 newSample0;
                     ma_int32 newSample1;
+                    if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) {
+                        return totalFramesRead;
+                    }
                     newSample0  = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8;
                     newSample0 += nibble0 * pWav->msadpcm.delta[0];
                     newSample0  = ma_dr_wav_clamp(newSample0, -32768, 32767);
@@ -83092,6 +83182,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_
                     }
                     pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1];
                     pWav->msadpcm.prevFrames[0][1] = newSample0;
+                    if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) {
+                        return totalFramesRead;
+                    }
                     newSample1  = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8;
                     newSample1 += nibble1 * pWav->msadpcm.delta[1];
                     newSample1  = ma_dr_wav_clamp(newSample1, -32768, 32767);
@@ -84336,6 +84429,10 @@ MA_PRIVATE ma_int16* ma_dr_wav__read_pcm_frames_and_close_s16(ma_dr_wav* pWav, u
     ma_int16* pSampleData;
     ma_uint64 framesRead;
     MA_DR_WAV_ASSERT(pWav != NULL);
+    if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int16)) {
+        ma_dr_wav_uninit(pWav);
+        return NULL;
+    }
     sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int16);
     if (sampleDataSize > MA_SIZE_MAX) {
         ma_dr_wav_uninit(pWav);
@@ -84370,6 +84467,10 @@ MA_PRIVATE float* ma_dr_wav__read_pcm_frames_and_close_f32(ma_dr_wav* pWav, unsi
     float* pSampleData;
     ma_uint64 framesRead;
     MA_DR_WAV_ASSERT(pWav != NULL);
+    if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(float)) {
+        ma_dr_wav_uninit(pWav);
+        return NULL;
+    }
     sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float);
     if (sampleDataSize > MA_SIZE_MAX) {
         ma_dr_wav_uninit(pWav);
@@ -84404,6 +84505,10 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u
     ma_int32* pSampleData;
     ma_uint64 framesRead;
     MA_DR_WAV_ASSERT(pWav != NULL);
+    if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int32)) {
+        ma_dr_wav_uninit(pWav);
+        return NULL;
+    }
     sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int32);
     if (sampleDataSize > MA_SIZE_MAX) {
         ma_dr_wav_uninit(pWav);
@@ -85786,7 +85891,7 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x)
             {
                 ma_uint64 r;
                 __asm__ __volatile__ (
-                    "lzcnt{ %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc"
+                    "rep; bsr{q %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc"
                 );
                 return (ma_uint32)r;
             }
@@ -85794,11 +85899,11 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x)
             {
                 ma_uint32 r;
                 __asm__ __volatile__ (
-                    "lzcnt{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc"
+                    "rep; bsr{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc"
                 );
                 return r;
             }
-        #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !defined(MA_64BIT)
+        #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !(defined(__thumb__) && !defined(__thumb2__)) && !defined(MA_64BIT)
             {
                 unsigned int r;
                 __asm__ __volatile__ (
@@ -88852,6 +88957,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea
                         }
                     }
                     blockSizeRemaining -= metadata.data.picture.pictureDataSize;
+                    (void)blockSizeRemaining;
                     metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData;
                     if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) {
                         onMeta(pUserDataMD, &metadata);
@@ -92276,57 +92382,42 @@ static type* ma_dr_flac__full_read_and_close_ ## extension (ma_dr_flac* pFlac, u
 {                                                                                                                                                                   \
     type* pSampleData = NULL;                                                                                                                                       \
     ma_uint64 totalPCMFrameCount;                                                                                                                               \
+    type buffer[4096];                                                                                                                                              \
+    ma_uint64 pcmFramesRead;                                                                                                                                    \
+    size_t sampleDataBufferSize = sizeof(buffer);                                                                                                                   \
                                                                                                                                                                     \
     MA_DR_FLAC_ASSERT(pFlac != NULL);                                                                                                                                   \
                                                                                                                                                                     \
-    totalPCMFrameCount = pFlac->totalPCMFrameCount;                                                                                                                 \
+    totalPCMFrameCount = 0;                                                                                                                                         \
                                                                                                                                                                     \
-    if (totalPCMFrameCount == 0) {                                                                                                                                  \
-        type buffer[4096];                                                                                                                                          \
-        ma_uint64 pcmFramesRead;                                                                                                                                \
-        size_t sampleDataBufferSize = sizeof(buffer);                                                                                                               \
-                                                                                                                                                                    \
-        pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks);                                                      \
-        if (pSampleData == NULL) {                                                                                                                                  \
-            goto on_error;                                                                                                                                          \
-        }                                                                                                                                                           \
-                                                                                                                                                                    \
-        while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) {          \
-            if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) {                                                   \
-                type* pNewSampleData;                                                                                                                               \
-                size_t newSampleDataBufferSize;                                                                                                                     \
+    pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks);                                                          \
+    if (pSampleData == NULL) {                                                                                                                                      \
+        goto on_error;                                                                                                                                              \
+    }                                                                                                                                                               \
                                                                                                                                                                     \
-                newSampleDataBufferSize = sampleDataBufferSize * 2;                                                                                                 \
-                pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks);    \
-                if (pNewSampleData == NULL) {                                                                                                                       \
-                    ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks);                                                                          \
-                    goto on_error;                                                                                                                                  \
-                }                                                                                                                                                   \
+    while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) {              \
+        if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) {                                                       \
+            type* pNewSampleData;                                                                                                                                   \
+            size_t newSampleDataBufferSize;                                                                                                                         \
                                                                                                                                                                     \
-                sampleDataBufferSize = newSampleDataBufferSize;                                                                                                     \
-                pSampleData = pNewSampleData;                                                                                                                       \
+            newSampleDataBufferSize = sampleDataBufferSize * 2;                                                                                                     \
+            pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks);        \
+            if (pNewSampleData == NULL) {                                                                                                                           \
+                ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks);                                                                              \
+                goto on_error;                                                                                                                                      \
             }                                                                                                                                                       \
                                                                                                                                                                     \
-            MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type)));                   \
-            totalPCMFrameCount += pcmFramesRead;                                                                                                                    \
-        }                                                                                                                                                           \
-                                                                                                                                                                    \
-                                                                                                                         \
-        MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type)));   \
-    } else {                                                                                                                                                        \
-        ma_uint64 dataSize = totalPCMFrameCount*pFlac->channels*sizeof(type);                                                                                   \
-        if (dataSize > (ma_uint64)MA_SIZE_MAX) {                                                                                                            \
-            goto on_error;                                                                                                        \
+            sampleDataBufferSize = newSampleDataBufferSize;                                                                                                         \
+            pSampleData = pNewSampleData;                                                                                                                           \
         }                                                                                                                                                           \
                                                                                                                                                                     \
-        pSampleData = (type*)ma_dr_flac__malloc_from_callbacks((size_t)dataSize, &pFlac->allocationCallbacks);               \
-        if (pSampleData == NULL) {                                                                                                                                  \
-            goto on_error;                                                                                                                                          \
-        }                                                                                                                                                           \
-                                                                                                                                                                    \
-        totalPCMFrameCount = ma_dr_flac_read_pcm_frames_##extension(pFlac, pFlac->totalPCMFrameCount, pSampleData);                                                     \
+        MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type)));                       \
+        totalPCMFrameCount += pcmFramesRead;                                                                                                                        \
     }                                                                                                                                                               \
                                                                                                                                                                     \
+                                                                                                                         \
+    MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type)));       \
+                                                                                                                                                                    \
     if (sampleRateOut) *sampleRateOut = pFlac->sampleRate;                                                                                                          \
     if (channelsOut) *channelsOut = pFlac->channels;                                                                                                                \
     if (totalPCMFrameCountOut) *totalPCMFrameCountOut = totalPCMFrameCount;                                                                                         \
@@ -94685,19 +94776,22 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on
                                 ((ma_uint32)ape[25] << 8)  |
                                 ((ma_uint32)ape[26] << 16) |
                                 ((ma_uint32)ape[27] << 24);
-                            streamEndOffset -= 32 + tagSize;
-                            streamLen       -= 32 + tagSize;
-                            if (onMeta != NULL) {
-                                if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) {
-                                    size_t apeTagSize = (size_t)tagSize + 32;
-                                    ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks);
-                                    if (pTagData != NULL) {
-                                        if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) {
-                                            ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize);
+                            if (32 + tagSize < streamLen) {
+                                streamEndOffset -= 32 + tagSize;
+                                streamLen       -= 32 + tagSize;
+                                if (onMeta != NULL) {
+                                    if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) {
+                                        size_t apeTagSize = (size_t)tagSize + 32;
+                                        ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks);
+                                        if (pTagData != NULL) {
+                                            if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) {
+                                                ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize);
+                                            }
+                                            ma_dr_mp3_free(pTagData, pAllocationCallbacks);
                                         }
-                                        ma_dr_mp3_free(pTagData, pAllocationCallbacks);
                                     }
                                 }
+                            } else {
                             }
                         }
                     }
@@ -94785,7 +94879,6 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on
         {
             ma_dr_mp3_bs bs;
             ma_dr_mp3_L3_gr_info grInfo[4];
-            const ma_uint8* pTagData = pFirstFrameData;
             ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE);
             if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) {
                 ma_dr_mp3_bs_get_bits(&bs, 16);
@@ -94793,6 +94886,7 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on
             if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) {
                 ma_bool32 isXing = MA_FALSE;
                 ma_bool32 isInfo = MA_FALSE;
+                const ma_uint8* pTagData;
                 const ma_uint8* pTagDataBeg;
                 pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8);
                 pTagData    = pTagDataBeg;
@@ -94892,7 +94986,6 @@ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_d
     ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData;
     ma_int64 newCursor;
     MA_DR_MP3_ASSERT(pMP3 != NULL);
-    newCursor = pMP3->memory.currentReadPos;
     if (origin == MA_DR_MP3_SEEK_SET) {
         newCursor = 0;
     } else if (origin == MA_DR_MP3_SEEK_CUR) {
@@ -95543,6 +95636,8 @@ static float* ma_dr_mp3__full_read_and_close_f32(ma_dr_mp3* pMP3, ma_dr_mp3_conf
             pNewFrames = (float*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks);
             if (pNewFrames == NULL) {
                 ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks);
+                pFrames = NULL;
+                totalFramesRead = 0;
                 break;
             }
             pFrames = pNewFrames;
@@ -95594,6 +95689,8 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c
             pNewFrames = (ma_int16*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks);
             if (pNewFrames == NULL) {
                 ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks);
+                pFrames = NULL;
+                totalFramesRead = 0;
                 break;
             }
             pFrames = pNewFrames;
diff --git a/llama/llama.go b/llama/llama.go
index 87844f2a242..137192646b8 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -352,7 +352,7 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 
 	err := -1
 	if loraAdapter != nil {
-		err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale)))
+		err = int(C.llama_set_adapters_lora(context.c, &loraAdapter, 1, (*C.float)(&scale)))
 	}
 	if err != 0 {
 		return errors.New("error applying lora from file")
diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
index 126dee34e12..7f80ebe8477 100644
--- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
+++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch
@@ -23,7 +23,7 @@ problem.
  8 files changed, 21 insertions(+), 2 deletions(-)
 
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index 8547ecc84..9f37ca70c 100644
+index 22c656996..c522fec01 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
 @@ -112,7 +112,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
@@ -34,7 +34,7 @@ index 8547ecc84..9f37ca70c 100644
  }
  
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
-@@ -591,6 +590,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -593,6 +592,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
  
      free(ctx->buffers);
      free(ctx);
@@ -42,7 +42,7 @@ index 8547ecc84..9f37ca70c 100644
  }
  
  static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-@@ -2125,6 +2125,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+@@ -2128,6 +2128,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
  static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      GGML_ASSERT(buffer);
      ggml_aligned_free(buffer->context, buffer->size);
@@ -54,7 +54,7 @@ index 8547ecc84..9f37ca70c 100644
  }
  
  static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-@@ -2177,7 +2182,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
+@@ -2180,7 +2185,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
  };
  
  static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
@@ -64,10 +64,10 @@ index 8547ecc84..9f37ca70c 100644
      /* .init_tensor     = */ NULL, // no initialization required
      /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
-index da624c587..efc63e092 100644
+index 3f3de9f0b..90a15f217 100644
 --- a/ggml/src/ggml-cann/ggml-cann.cpp
 +++ b/ggml/src/ggml-cann/ggml-cann.cpp
-@@ -831,6 +831,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
+@@ -845,6 +845,7 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
  static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
      delete ctx;
@@ -75,7 +75,7 @@ index da624c587..efc63e092 100644
  }
  
  /**
-@@ -1570,6 +1571,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
+@@ -1559,6 +1560,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf
   */
  static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
      ACL_CHECK(aclrtFreeHost(buffer->context));
@@ -84,10 +84,10 @@ index da624c587..efc63e092 100644
  
  /**
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index ab0f6fe9c..6519af435 100644
+index 7e6d33035..9625513fa 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -583,6 +583,7 @@ struct ggml_backend_cuda_buffer_context {
+@@ -584,6 +584,7 @@ struct ggml_backend_cuda_buffer_context {
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
      delete ctx;
@@ -95,7 +95,7 @@ index ab0f6fe9c..6519af435 100644
  }
  
  static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
-@@ -838,6 +839,7 @@ struct ggml_backend_cuda_split_buffer_context {
+@@ -839,6 +840,7 @@ struct ggml_backend_cuda_split_buffer_context {
  static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
      delete ctx;
@@ -103,7 +103,7 @@ index ab0f6fe9c..6519af435 100644
  }
  
  static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -1119,6 +1121,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
+@@ -1120,6 +1122,7 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
  
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      CUDA_CHECK(cudaFreeHost(buffer->context));
@@ -112,10 +112,10 @@ index ab0f6fe9c..6519af435 100644
  
  static void * ggml_cuda_host_malloc(size_t size) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
-index 70bf6f3d9..f2b7fe692 100644
+index 1c705362f..9f20b9a25 100644
 --- a/ggml/src/ggml-metal/ggml-metal.cpp
 +++ b/ggml/src/ggml-metal/ggml-metal.cpp
-@@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b
+@@ -29,6 +29,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b
      GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
  
      ggml_metal_buffer_free(ctx);
@@ -123,7 +123,7 @@ index 70bf6f3d9..f2b7fe692 100644
  }
  
  static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
-@@ -99,6 +100,7 @@ static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t
+@@ -103,6 +104,7 @@ static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t
      GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
  
      ggml_metal_buffer_free(ctx);
@@ -132,10 +132,10 @@ index 70bf6f3d9..f2b7fe692 100644
  
  static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
-index 0d37587f6..ff373d413 100644
+index 3da022ed8..d7742fa13 100644
 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp
 +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
-@@ -3417,6 +3417,7 @@ struct ggml_backend_opencl_buffer_context {
+@@ -3877,6 +3877,7 @@ struct ggml_backend_opencl_buffer_context {
  static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
      delete ctx;
@@ -144,10 +144,10 @@ index 0d37587f6..ff373d413 100644
  
  static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
-index 18a45d2d9..89041805e 100644
+index d7c8ad8c1..281fa1bdb 100644
 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp
 +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
-@@ -556,6 +556,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -557,6 +557,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
      RPC_STATUS_ASSERT(status);
      delete ctx;
@@ -156,7 +156,7 @@ index 18a45d2d9..89041805e 100644
  
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
-index e996d98be..84b679315 100644
+index 0614d7e8f..336172700 100644
 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp
 +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
 @@ -356,6 +356,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
@@ -175,19 +175,19 @@ index e996d98be..84b679315 100644
  }
  
  static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -1159,6 +1161,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
+@@ -1175,6 +1177,7 @@ inline void free_aligned_mem_host(void * memblock) {
  
  static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-     ggml_sycl_host_free(buffer->context);
+     free_aligned_mem_host((void *)buffer->context);
 +    delete buffer;
  }
  
  static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index 34ec09d40..120191ca0 100644
+index 23d6d39e0..c9ca7a986 100644
 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-@@ -12365,6 +12365,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+@@ -13227,6 +13227,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
      ggml_vk_destroy_buffer(ctx->dev_buffer);
      delete ctx;
@@ -195,7 +195,7 @@ index 34ec09d40..120191ca0 100644
  }
  
  static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
-@@ -12508,6 +12509,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
+@@ -13370,6 +13371,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe
  static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
      ggml_vk_host_free(vk_instance.devices[0], buffer->context);
diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch
index 9cee5c56f8e..4153a5123e8 100644
--- a/llama/patches/0002-pretokenizer.patch
+++ b/llama/patches/0002-pretokenizer.patch
@@ -6,14 +6,14 @@ Subject: [PATCH] pretokenizer
 allow for an unset pretokenizer with a warning in the
 logs instead of throwing an error
 ---
- src/llama-vocab.cpp | 14 +++-----------
- 1 file changed, 3 insertions(+), 11 deletions(-)
+ src/llama-vocab.cpp | 17 +++++------------
+ 1 file changed, 5 insertions(+), 12 deletions(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 7b01a2edf..63250cdf1 100644
+index 194eed238..385484304 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -1825,16 +1825,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+@@ -1871,16 +1871,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
          if (type == LLAMA_VOCAB_TYPE_BPE) {
              add_space_prefix = false;
              clean_spaces = true;
@@ -31,8 +31,8 @@ index 7b01a2edf..63250cdf1 100644
                  pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -2015,7 +2006,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
-                 pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
+@@ -2091,7 +2082,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
                  clean_spaces = false;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
@@ -41,3 +41,20 @@ index 7b01a2edf..63250cdf1 100644
              }
          } else if (type == LLAMA_VOCAB_TYPE_SPM) {
              pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+@@ -2135,6 +2127,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+         scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+     }
+ 
++    const uint32_t n_scores = score_idx != -1 ? gguf_get_arr_n(ctx, score_idx) : 0;
+     const int * toktypes = nullptr;
+     const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+     if (toktype_idx != -1) {
+@@ -2156,7 +2149,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+ 
+         auto & token_data = id_to_token[i];
+         token_data.text  = std::move(word);
+-        token_data.score = scores ? scores[i] : 0.0f;
++        token_data.score = (scores && i < n_scores) ? scores[i] : 0.0f;
+         token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+ 
+         if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch
index 73d10732d1d..8dab0219406 100644
--- a/llama/patches/0003-clip-unicode.patch
+++ b/llama/patches/0003-clip-unicode.patch
@@ -6,14 +6,14 @@ Subject: [PATCH] clip-unicode
 fixes loading vision models in llama.cpp on windows
 filesystems for paths that include wide characters
 ---
- tools/mtmd/clip.cpp | 39 +++++++++++++++++++++++++++++++++++++++
- 1 file changed, 39 insertions(+)
+ tools/mtmd/clip.cpp | 47 +++++++++++++++++++++++++++++++++++++++++----
+ 1 file changed, 43 insertions(+), 4 deletions(-)
 
 diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
-index 35e3aef0a..84a3796b5 100644
+index 607d4b837..d12110a5c 100644
 --- a/tools/mtmd/clip.cpp
 +++ b/tools/mtmd/clip.cpp
-@@ -24,6 +24,19 @@
+@@ -25,6 +25,19 @@
  #include 
  #include 
  
@@ -33,7 +33,7 @@ index 35e3aef0a..84a3796b5 100644
  struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL};
  
  //#define CLIP_DEBUG_FUNCTIONS
-@@ -1619,7 +1632,29 @@ struct clip_model_loader {
+@@ -1895,7 +1908,29 @@ struct clip_model_loader {
          {
              std::vector read_buf;
  
@@ -63,7 +63,7 @@ index 35e3aef0a..84a3796b5 100644
              if (!fin) {
                  throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
              }
-@@ -1646,7 +1681,11 @@ struct clip_model_loader {
+@@ -1922,7 +1957,11 @@ struct clip_model_loader {
                      ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
                  }
              }
@@ -75,3 +75,39 @@ index 35e3aef0a..84a3796b5 100644
  
              LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
          }
+@@ -2305,7 +2344,7 @@ struct img_tool {
+             std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
+         dst.nx = target_resolution.width;
+         dst.ny = target_resolution.height;
+-        dst.buf.resize(3 * dst.nx * dst.ny);
++        dst.buf.resize(3 * static_cast<size_t>(dst.nx) * static_cast<size_t>(dst.ny));
+ 
+         if (dst.nx == src.nx && dst.ny == src.ny) {
+             // no resize needed, simple copy
+@@ -2358,7 +2397,7 @@ struct img_tool {
+     static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+         dst.nx = w;
+         dst.ny = h;
+-        dst.buf.resize(3 * w * h);
++        dst.buf.resize(3 * static_cast<size_t>(w) * static_cast<size_t>(h));
+ 
+         for (int i = 0; i < h; ++i) {
+             for (int j = 0; j < w; ++j) {
+@@ -2455,7 +2494,7 @@ private:
+     static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
+         dst.nx = target_width;
+         dst.ny = target_height;
+-        dst.buf.resize(3 * target_width * target_height);
++        dst.buf.resize(3 * static_cast<size_t>(target_width) * static_cast<size_t>(target_height));
+ 
+         float x_ratio = static_cast<float>(src.nx - 1) / target_width;
+         float y_ratio = static_cast<float>(src.ny - 1) / target_height;
+@@ -2494,7 +2533,7 @@ private:
+ 
+         dst.nx = target_width;
+         dst.ny = target_height;
+-        dst.buf.resize(3 * target_width * target_height);
++        dst.buf.resize(3 * static_cast<size_t>(target_width) * static_cast<size_t>(target_height));
+ 
+         float Cc;
+         float C[5] = {};
diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch
index f267356ea61..ac83c09be4e 100644
--- a/llama/patches/0004-solar-pro.patch
+++ b/llama/patches/0004-solar-pro.patch
@@ -19,10 +19,10 @@ adds support for the Solar Pro architecture
  create mode 100644 src/models/solar.cpp
 
 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
-index 4192af7c0..bd44d73e7 100644
+index 283823fa9..d3232121b 100644
 --- a/src/CMakeLists.txt
 +++ b/src/CMakeLists.txt
-@@ -125,6 +125,7 @@ add_library(llama
+@@ -140,6 +140,7 @@ add_library(llama
              models/seed-oss.cpp
              models/smallthinker.cpp
              models/smollm3.cpp
@@ -31,10 +31,10 @@ index 4192af7c0..bd44d73e7 100644
              models/starcoder.cpp
              models/starcoder2.cpp
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 8caf80afc..2ce8ffec0 100644
+index 47e8d5278..977783cbe 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
-@@ -87,6 +87,7 @@ static const std::map LLM_ARCH_NAMES = {
+@@ -95,6 +95,7 @@ static const std::map LLM_ARCH_NAMES = {
      { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
      { LLM_ARCH_GRANITE_HYBRID,   "granitehybrid"    },
      { LLM_ARCH_CHAMELEON,        "chameleon"        },
@@ -42,15 +42,15 @@ index 8caf80afc..2ce8ffec0 100644
      { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
      { LLM_ARCH_PLM,              "plm"              },
      { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
-@@ -208,6 +209,7 @@ static const std::map LLM_KV_NAMES = {
+@@ -227,6 +228,7 @@ static const std::map LLM_KV_NAMES = {
      { LLM_KV_ATTENTION_OUTPUT_SCALE,                 "%s.attention.output_scale"                 },
      { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,           "%s.attention.temperature_length"           },
      { LLM_KV_ATTENTION_TEMPERATURE_SCALE,            "%s.attention.temperature_scale"            },
 +    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
      { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
      { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
- 
-@@ -339,6 +341,7 @@ static const std::map LLM_TENSOR_NAMES = {
+     { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,           "%s.attention.indexer.head_count"           },
+@@ -365,6 +367,7 @@ static const std::map LLM_TENSOR_NAMES = {
      { LLM_TENSOR_ATTN_QKV,                               "blk.%d.attn_qkv" },
      { LLM_TENSOR_LAYER_OUT_NORM,                         "blk.%d.layer_output_norm" },
      { LLM_TENSOR_ATTN_OUT_NORM,                          "blk.%d.attn_output_norm" },
@@ -58,9 +58,9 @@ index 8caf80afc..2ce8ffec0 100644
      { LLM_TENSOR_POS_EMBD,                               "position_embd" },
      { LLM_TENSOR_FFN_ACT,                                "blk.%d.ffn.act" },
      { LLM_TENSOR_TOKEN_EMBD_NORM,                        "token_embd_norm" },
-@@ -2176,6 +2179,22 @@ static std::set llm_get_tensor_names(llm_arch arch) {
-             return {
-                 LLM_TENSOR_TOKEN_EMBD,
+@@ -2485,6 +2488,22 @@ static std::set llm_get_tensor_names(llm_arch arch) {
+                 LLM_TENSOR_FFN_DOWN,
+                 LLM_TENSOR_FFN_UP,
              };
 +        case LLM_ARCH_SOLAR:
 +            return {
@@ -70,7 +70,7 @@ index 8caf80afc..2ce8ffec0 100644
 +                LLM_TENSOR_ATTN_NORM,
 +                LLM_TENSOR_ATTN_Q,
 +                LLM_TENSOR_ATTN_K,
-+                LLM_TENSOR_ATTN_V,
++                LLM_TENSOR_ATTN_V,
 +                LLM_TENSOR_ATTN_OUT,
 +                LLM_TENSOR_FFN_NORM,
 +                LLM_TENSOR_FFN_GATE,
@@ -78,10 +78,10 @@ index 8caf80afc..2ce8ffec0 100644
 +                LLM_TENSOR_FFN_UP,
 +                LLM_TENSOR_BSKCN_TV,
 +            };
-         default:
-             GGML_ABORT("unknown architecture for tensor mapping");
-     }
-@@ -2344,6 +2363,7 @@ static const std::map LLM_TENSOR_INFOS = {
+         case LLM_ARCH_KIMI_LINEAR:
+             return {
+                 LLM_TENSOR_TOKEN_EMBD,
+@@ -2713,6 +2732,7 @@ static const std::map LLM_TENSOR_INFOS = {
      {LLM_TENSOR_LAUREL_POST_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@@ -90,10 +90,10 @@ index 8caf80afc..2ce8ffec0 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 6cbf9b1f8..14d461c76 100644
+index 6d1b1df31..e9f2739ac 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
-@@ -91,6 +91,7 @@ enum llm_arch {
+@@ -99,6 +99,7 @@ enum llm_arch {
      LLM_ARCH_GRANITE_MOE,
      LLM_ARCH_GRANITE_HYBRID,
      LLM_ARCH_CHAMELEON,
@@ -101,27 +101,27 @@ index 6cbf9b1f8..14d461c76 100644
      LLM_ARCH_WAVTOKENIZER_DEC,
      LLM_ARCH_PLM,
      LLM_ARCH_BAILINGMOE,
-@@ -212,6 +213,7 @@ enum llm_kv {
+@@ -231,6 +232,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_OUTPUT_SCALE,
      LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
      LLM_KV_ATTENTION_TEMPERATURE_SCALE,
 +    LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
      LLM_KV_ATTENTION_KEY_LENGTH_MLA,
      LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
- 
-@@ -465,6 +467,7 @@ enum llm_tensor {
-     LLM_TENSOR_ENC_OUTPUT_NORM,
+     LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
+@@ -502,6 +504,7 @@ enum llm_tensor {
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
+     LLM_TENSOR_CLS_NORM,
 +    LLM_TENSOR_BSKCN_TV,
      LLM_TENSOR_CONV1D,
      LLM_TENSOR_CONVNEXT_DW,
      LLM_TENSOR_CONVNEXT_NORM,
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index fe1fa4341..aabff2f06 100644
+index 756dda1a7..515a900b3 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
-@@ -163,6 +163,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
+@@ -181,6 +181,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
      return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
  }
  
@@ -137,7 +137,7 @@ index fe1fa4341..aabff2f06 100644
      if (il < n_layer) {
          return swa_layers[il];
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index f6e95b5d2..c6e673276 100644
+index c4b2a99da..ccca1bd5f 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
 @@ -65,6 +65,8 @@ struct llama_hparams {
@@ -149,7 +149,7 @@ index f6e95b5d2..c6e673276 100644
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
      uint32_t n_lora_kv          = 0;
-@@ -259,6 +261,9 @@ struct llama_hparams {
+@@ -279,6 +281,9 @@ struct llama_hparams {
  
      uint32_t n_pos_per_embd() const;
  
@@ -158,12 +158,12 @@ index f6e95b5d2..c6e673276 100644
 +
      bool is_swa(uint32_t il) const;
  
-     bool has_kv(uint32_t il) const;
+     // note: currently only support if either all or none of the layers are MLA
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index ca2ea2461..8916a6242 100644
+index 1501e392c..37b69a4b3 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
-@@ -466,7 +466,7 @@ namespace GGUFMeta {
+@@ -497,7 +497,7 @@ namespace GGUFMeta {
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
@@ -173,10 +173,10 @@ index ca2ea2461..8916a6242 100644
  llama_model_loader::llama_model_loader(
          const std::string & fname,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index ae8207ee1..00cd579e0 100644
+index dabf3b308..6436cde36 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1995,6 +1995,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+@@ -2194,6 +2194,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      default: type = LLM_TYPE_UNKNOWN;
                 }
              } break;
@@ -198,7 +198,7 @@ index ae8207ee1..00cd579e0 100644
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-@@ -5429,6 +5444,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+@@ -6146,6 +6161,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  
                          layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  
@@ -233,7 +233,7 @@ index ae8207ee1..00cd579e0 100644
                          layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                          layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                          layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-@@ -7534,6 +7577,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
+@@ -8723,6 +8766,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
              {
                  llm = std::make_unique(*this, params);
              } break;
@@ -244,7 +244,7 @@ index ae8207ee1..00cd579e0 100644
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  llm = std::make_unique(*this, params);
-@@ -7798,6 +7845,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
+@@ -9026,6 +9073,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
          case LLM_ARCH_GRANITE_MOE:
          case LLM_ARCH_GRANITE_HYBRID:
          case LLM_ARCH_CHAMELEON:
@@ -253,10 +253,10 @@ index ae8207ee1..00cd579e0 100644
          case LLM_ARCH_NEO_BERT:
          case LLM_ARCH_SMOLLM3:
 diff --git a/src/llama-model.h b/src/llama-model.h
-index c6eb95318..b378b23ec 100644
+index d7c3e7d1c..679977bee 100644
 --- a/src/llama-model.h
 +++ b/src/llama-model.h
-@@ -76,6 +76,7 @@ enum llm_type {
+@@ -80,6 +80,7 @@ enum llm_type {
      LLM_TYPE_15B,
      LLM_TYPE_16B,
      LLM_TYPE_20B,
@@ -264,9 +264,9 @@ index c6eb95318..b378b23ec 100644
      LLM_TYPE_26B,
      LLM_TYPE_27B,
      LLM_TYPE_30B,
-@@ -405,6 +406,8 @@ struct llama_layer {
-     struct ggml_tensor * ffn_act_beta    = nullptr;
-     struct ggml_tensor * ffn_act_eps     = nullptr;
+@@ -440,6 +441,8 @@ struct llama_layer {
+     struct ggml_tensor * indexer_attn_k   = nullptr;
+     struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias
  
 +    struct ggml_tensor * bskcn_tv = nullptr;
 +
@@ -274,10 +274,10 @@ index c6eb95318..b378b23ec 100644
  
      struct llama_layer_convnext convnext;
 diff --git a/src/models/models.h b/src/models/models.h
-index ffb36acc6..6d84a185d 100644
+index 0712d03d8..d076bf288 100644
 --- a/src/models/models.h
 +++ b/src/models/models.h
-@@ -515,6 +515,11 @@ struct llm_build_smollm3 : public llm_graph_context {
+@@ -651,6 +651,11 @@ struct llm_build_smollm3 : public llm_graph_context {
      llm_build_smollm3(const llama_model & model, const llm_graph_params & params);
  };
  
diff --git a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch
deleted file mode 100644
index 9aa2ae46bb8..00000000000
--- a/llama/patches/0005-fix-deepseek-deseret-regex.patch
+++ /dev/null
@@ -1,72 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Tue, 8 Apr 2025 19:43:06 -0700
-Subject: [PATCH] fix deepseek deseret regex
-
-on some systems, deepseek's regex would throw an error
-on windows due to the deseret characters in the matching
-regex
----
- src/llama-vocab.cpp |  2 +-
- src/unicode.cpp     | 21 +++++++++++++++++++++
- 2 files changed, 22 insertions(+), 1 deletion(-)
-
-diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 63250cdf1..dd86a1745 100644
---- a/src/llama-vocab.cpp
-+++ b/src/llama-vocab.cpp
-@@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
-             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
-                 regex_exprs = {
-                     "[\r\n]",
--                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-+                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z\U00010400-\U0001044f𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                     "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
-                     "\\s+$",
-                     "[一-龥ࠀ-一가-퟿]+",
-diff --git a/src/unicode.cpp b/src/unicode.cpp
-index bb44edfad..13ced055f 100644
---- a/src/unicode.cpp
-+++ b/src/unicode.cpp
-@@ -2,6 +2,11 @@
- #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
- #endif
- 
-+#if defined(_WIN32)
-+#define WIN32_LEAN_AND_MEAN
-+#include 
-+#endif
-+
- #include "unicode.h"
- #include "unicode-data.h"
- 
-@@ -200,6 +205,21 @@ static std::unordered_map unicode_utf8_to_byte_map() {
- }
- 
- static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
-+#ifdef _WIN32
-+    int wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, NULL, 0);
-+    if (!wlen) {
-+        throw std::invalid_argument("failed to convert regex");
-+    }
-+    wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
-+    wlen = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, wbuf, wlen);
-+    if (!wlen) {
-+        free(wbuf);
-+        throw std::invalid_argument("failed to convert regex");
-+    }
-+    std::wstring ret = std::wstring(wbuf);
-+    free(wbuf);
-+    return ret;
-+#else
- #if defined(__clang__)
-     // disable C++17 deprecation warning for std::codecvt_utf8
- #    pragma clang diagnostic push
-@@ -218,6 +238,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
- #endif
- 
-     return conv.from_bytes(s);
-+#endif
- }
- 
- static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) {
diff --git a/llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0005-maintain-ordering-for-rules-for-grammar.patch
similarity index 100%
rename from llama/patches/0006-maintain-ordering-for-rules-for-grammar.patch
rename to llama/patches/0005-maintain-ordering-for-rules-for-grammar.patch
diff --git a/llama/patches/0007-sort-devices-by-score.patch b/llama/patches/0006-sort-devices-by-score.patch
similarity index 89%
rename from llama/patches/0007-sort-devices-by-score.patch
rename to llama/patches/0006-sort-devices-by-score.patch
index f45da396a4c..9ff9d99d022 100644
--- a/llama/patches/0007-sort-devices-by-score.patch
+++ b/llama/patches/0006-sort-devices-by-score.patch
@@ -11,10 +11,10 @@ with the fastest acceleration is loaded
  1 file changed, 13 insertions(+), 8 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 4181a714a..079dba211 100644
+index 311fa5fe3..03e32b2d5 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
-@@ -183,7 +183,7 @@ struct ggml_backend_reg_entry {
+@@ -106,7 +106,7 @@ struct ggml_backend_reg_entry {
  
  struct ggml_backend_registry {
      std::vector backends;
@@ -23,7 +23,7 @@ index 4181a714a..079dba211 100644
  
      ggml_backend_registry() {
  #ifdef GGML_USE_CUDA
-@@ -237,7 +237,7 @@ struct ggml_backend_registry {
+@@ -169,7 +169,7 @@ struct ggml_backend_registry {
          }
      }
  
@@ -32,7 +32,7 @@ index 4181a714a..079dba211 100644
          if (!reg) {
              return;
          }
-@@ -248,15 +248,20 @@ struct ggml_backend_registry {
+@@ -180,15 +180,20 @@ struct ggml_backend_registry {
  #endif
          backends.push_back({ reg, std::move(handle) });
          for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
@@ -56,7 +56,7 @@ index 4181a714a..079dba211 100644
      }
  
      ggml_backend_reg_t load_backend(const fs::path & path, bool silent) {
-@@ -300,7 +305,7 @@ struct ggml_backend_registry {
+@@ -232,7 +237,7 @@ struct ggml_backend_registry {
  
          GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str());
  
@@ -65,7 +65,7 @@ index 4181a714a..079dba211 100644
  
          return reg;
      }
-@@ -323,7 +328,7 @@ struct ggml_backend_registry {
+@@ -255,7 +260,7 @@ struct ggml_backend_registry {
          // remove devices
          devices.erase(
              std::remove_if(devices.begin(), devices.end(),
@@ -74,7 +74,7 @@ index 4181a714a..079dba211 100644
              devices.end());
  
          // remove backend
-@@ -381,7 +386,7 @@ size_t ggml_backend_dev_count() {
+@@ -313,7 +318,7 @@ size_t ggml_backend_dev_count() {
  
  ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
      GGML_ASSERT(index < ggml_backend_dev_count());
diff --git a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0007-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
similarity index 79%
rename from llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
rename to llama/patches/0007-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
index 315613e0ad4..a643db10df5 100644
--- a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
+++ b/llama/patches/0007-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 4c04c3300..f4747f262 100644
+index 265023733..a20b7e54a 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -345,6 +345,7 @@ function(ggml_add_cpu_backend_variant tag_name)
+@@ -346,6 +346,7 @@ function(ggml_add_cpu_backend_variant tag_name)
      endif()
  
      ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -19,11 +19,11 @@ index 4c04c3300..f4747f262 100644
  endfunction()
  
  ggml_add_backend(CPU)
-@@ -355,6 +356,7 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -356,6 +357,7 @@ if (GGML_CPU_ALL_VARIANTS)
      elseif (GGML_CPU_ARM_ARCH)
          message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
      endif()
 +    add_custom_target(ggml-cpu)
      if (GGML_SYSTEM_ARCH STREQUAL "x86")
          ggml_add_cpu_backend_variant(x64)
-         ggml_add_cpu_backend_variant(sse42        SSE42)
+         ggml_add_cpu_backend_variant(sse42              SSE42)
diff --git a/llama/patches/0009-remove-amx.patch b/llama/patches/0008-remove-amx.patch
similarity index 55%
rename from llama/patches/0009-remove-amx.patch
rename to llama/patches/0008-remove-amx.patch
index cace86f96f8..1c87a6d8ff7 100644
--- a/llama/patches/0009-remove-amx.patch
+++ b/llama/patches/0008-remove-amx.patch
@@ -5,21 +5,22 @@ Subject: [PATCH] remove amx
 
 disable amx as it reduces performance on some systems
 ---
- ggml/src/CMakeLists.txt | 4 ----
- 1 file changed, 4 deletions(-)
+ ggml/src/CMakeLists.txt | 5 +----
+ 1 file changed, 1 insertion(+), 4 deletions(-)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index f4747f262..d55aed348 100644
+index a20b7e54a..dbcb5ef5d 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -365,10 +365,6 @@ if (GGML_CPU_ALL_VARIANTS)
-         ggml_add_cpu_backend_variant(skylakex     SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
-         ggml_add_cpu_backend_variant(icelake      SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-         ggml_add_cpu_backend_variant(alderlake    SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+@@ -380,10 +380,7 @@ if (GGML_CPU_ALL_VARIANTS)
+             ggml_add_cpu_backend_variant(zen4           SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
+         endif()
+         ggml_add_cpu_backend_variant(alderlake          SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
 -        if (NOT MSVC)
 -            # MSVC doesn't support AMX
--            ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+-            ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
 -        endif()
++        # AMX variants removed by ollama - sapphirerapids with AMX_TILE AMX_INT8 not included
      elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
          if (CMAKE_SYSTEM_NAME MATCHES "Linux")
              # Many of these features are optional so we build versions with popular
diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0009-fix-string-arr-kv-loading.patch
similarity index 91%
rename from llama/patches/0010-fix-string-arr-kv-loading.patch
rename to llama/patches/0009-fix-string-arr-kv-loading.patch
index 63acee83322..870a4e89e77 100644
--- a/llama/patches/0010-fix-string-arr-kv-loading.patch
+++ b/llama/patches/0009-fix-string-arr-kv-loading.patch
@@ -25,10 +25,10 @@ index 79ee20206..3efb22f01 100644
      // get ith C string from array with given key_id
      GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
 diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
-index b165d8bdc..f91d4faba 100644
+index cbeedf6c4..6fb6ea927 100644
 --- a/ggml/src/gguf.cpp
 +++ b/ggml/src/gguf.cpp
-@@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
+@@ -915,10 +915,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
  
  const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
      GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
@@ -44,7 +44,7 @@ index b165d8bdc..f91d4faba 100644
  const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
      GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
      GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
-@@ -902,7 +906,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+@@ -1012,7 +1016,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
  const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
      GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
      GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
@@ -53,10 +53,10 @@ index b165d8bdc..f91d4faba 100644
  }
  
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index dd86a1745..d63ce9c84 100644
+index 385484304..b244318bd 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -1781,9 +1781,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+@@ -1827,9 +1827,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
              const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
              if (precompiled_charsmap_keyidx != -1) {
                  const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0010-ollama-debug-tensor.patch
similarity index 89%
rename from llama/patches/0011-ollama-debug-tensor.patch
rename to llama/patches/0010-ollama-debug-tensor.patch
index a2a4eb6b6c1..859b0ee01ef 100644
--- a/llama/patches/0011-ollama-debug-tensor.patch
+++ b/llama/patches/0010-ollama-debug-tensor.patch
@@ -8,19 +8,19 @@ Subject: [PATCH] ollama debug tensor
  1 file changed, 6 insertions(+)
 
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index a59b51893..53891a91f 100644
+index 64eb01a4e..2aa324852 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
 @@ -15,6 +15,8 @@
- #include "ops.h"
  #include "ggml.h"
+ #include "common.h"
  
 +#include "ollama-debug.h"
 +
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include  // using malloc.h with MSC/MINGW
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
-@@ -2945,6 +2947,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
+@@ -2967,6 +2969,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
  
          ggml_compute_forward(¶ms, node);
  
diff --git a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0011-add-ollama-vocab-for-grammar-support.patch
similarity index 86%
rename from llama/patches/0012-add-ollama-vocab-for-grammar-support.patch
rename to llama/patches/0011-add-ollama-vocab-for-grammar-support.patch
index f26e1bc29e2..35486de2d2c 100644
--- a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch
+++ b/llama/patches/0011-add-ollama-vocab-for-grammar-support.patch
@@ -4,16 +4,16 @@ Date: Mon, 21 Apr 2025 13:30:31 -0700
 Subject: [PATCH] add ollama vocab for grammar support
 
 ---
- src/llama-grammar.cpp  | 48 ++++++++++++++++++++++++++++++++++++------
- src/llama-grammar.h    | 14 ++++++++++++
- src/llama-sampling.cpp |  6 +++---
+ src/llama-grammar.cpp | 48 ++++++++++++++++++++++++++++++++++++-------
+ src/llama-grammar.h   | 14 +++++++++++++
+ src/llama-sampler.cpp |  6 +++---
  3 files changed, 58 insertions(+), 10 deletions(-)
 
 diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
-index 75d5d750c..a0299d181 100644
+index 2d55070ce..9d3b896a6 100644
 --- a/src/llama-grammar.cpp
 +++ b/src/llama-grammar.cpp
-@@ -1041,6 +1041,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
+@@ -1079,6 +1079,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
  
  struct llama_grammar * llama_grammar_init_impl(
          const struct llama_vocab * vocab,
@@ -21,7 +21,7 @@ index 75d5d750c..a0299d181 100644
          const llama_grammar_element ** rules,
          size_t n_rules,
          size_t start_rule_index) {
-@@ -1096,6 +1097,7 @@ struct llama_grammar * llama_grammar_init_impl(
+@@ -1134,6 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(
      // then the pointers would be invalidated when the local vec_rules goes out of scope.
      return new llama_grammar {
          vocab,
@@ -29,7 +29,7 @@ index 75d5d750c..a0299d181 100644
          std::move(vec_rules),
          std::move(stacks),
          /* .partial_utf8 = */             {},
-@@ -1110,6 +1112,7 @@ struct llama_grammar * llama_grammar_init_impl(
+@@ -1148,6 +1150,7 @@ struct llama_grammar * llama_grammar_init_impl(
  
  struct llama_grammar * llama_grammar_init_impl(
          const struct llama_vocab * vocab,
@@ -37,7 +37,7 @@ index 75d5d750c..a0299d181 100644
                        const char * grammar_str,
                        const char * grammar_root,
                                bool lazy,
-@@ -1202,6 +1205,7 @@ struct llama_grammar * llama_grammar_init_impl(
+@@ -1240,6 +1243,7 @@ struct llama_grammar * llama_grammar_init_impl(
      // then the pointers would be invalidated when the local vec_rules goes out of scope.
      return new llama_grammar {
          vocab,
@@ -45,7 +45,7 @@ index 75d5d750c..a0299d181 100644
          std::move(vec_rules),
          std::move(stacks),
          /* .partial_utf8 = */             {},
-@@ -1225,6 +1229,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
+@@ -1263,6 +1267,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
  struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
      auto * result = new llama_grammar {
          grammar.vocab,
@@ -53,7 +53,7 @@ index 75d5d750c..a0299d181 100644
          grammar.rules,
          grammar.stacks,
          grammar.partial_utf8,
-@@ -1253,7 +1258,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
+@@ -1291,7 +1296,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
  }
  
  void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
@@ -61,7 +61,7 @@ index 75d5d750c..a0299d181 100644
  
      if (grammar.awaiting_trigger) {
          return;
-@@ -1275,9 +1279,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
+@@ -1313,9 +1317,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
  
      for (size_t i = 0; i < cur_p->size; ++i) {
          const llama_token id      = cur_p->data[i].id;
@@ -77,7 +77,7 @@ index 75d5d750c..a0299d181 100644
              if (!allow_eog) {
                  cur_p->data[i].logit = -INFINITY;
              }
-@@ -1296,9 +1304,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
+@@ -1334,9 +1342,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
  }
  
  void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
@@ -90,7 +90,7 @@ index 75d5d750c..a0299d181 100644
  
      if (grammar.awaiting_trigger) {
          if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
-@@ -1353,13 +1362,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
+@@ -1380,13 +1389,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
          }
      }
  
@@ -107,7 +107,7 @@ index 75d5d750c..a0299d181 100644
      }
  
      llama_grammar_accept_token(grammar, token, piece);
-@@ -1435,3 +1445,27 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke
+@@ -1462,3 +1472,27 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke
      }
  }
  
@@ -136,7 +136,7 @@ index 75d5d750c..a0299d181 100644
 +    }
 +}
 diff --git a/src/llama-grammar.h b/src/llama-grammar.h
-index a4c978ac1..5c0da4049 100644
+index b5a0e588e..57847583a 100644
 --- a/src/llama-grammar.h
 +++ b/src/llama-grammar.h
 @@ -6,8 +6,19 @@
@@ -159,7 +159,7 @@ index a4c978ac1..5c0da4049 100644
  
  // grammar element type
  enum llama_gretype {
-@@ -127,6 +138,7 @@ struct llama_grammar {
+@@ -129,6 +140,7 @@ struct llama_grammar {
  
      // note: allow null vocab for testing (not great)
      const llama_vocab * vocab;
@@ -167,7 +167,7 @@ index a4c978ac1..5c0da4049 100644
  
      const llama_grammar_rules  rules;  // TODO: shared ptr
            llama_grammar_stacks stacks;
-@@ -155,12 +167,14 @@ struct llama_grammar {
+@@ -157,12 +169,14 @@ struct llama_grammar {
  // note: needed for tests (not great)
  struct llama_grammar * llama_grammar_init_impl(
          const struct llama_vocab * vocab,
@@ -182,11 +182,11 @@ index a4c978ac1..5c0da4049 100644
                        const char * grammar_str,
                        const char * grammar_root,
                                bool lazy,
-diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
-index 3f4a729bc..38a30ea05 100644
---- a/src/llama-sampling.cpp
-+++ b/src/llama-sampling.cpp
-@@ -1561,7 +1561,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
+diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp
+index 9bbc5dbde..5cf66b63f 100644
+--- a/src/llama-sampler.cpp
++++ b/src/llama-sampler.cpp
+@@ -2477,7 +2477,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
          trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
      }
  
@@ -195,7 +195,7 @@ index 3f4a729bc..38a30ea05 100644
                                                   ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                   ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
  
-@@ -1639,9 +1639,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
+@@ -2559,9 +2559,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
              trigger_pattern += ")[\\s\\S]*";
  
              std::array tmp_trigger_patterns = { trigger_pattern.c_str() };
diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0012-add-argsort-and-cuda-copy-for-i32.patch
similarity index 95%
rename from llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch
rename to llama/patches/0012-add-argsort-and-cuda-copy-for-i32.patch
index a022e33eb75..ae15ac01a0a 100644
--- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch
+++ b/llama/patches/0012-add-argsort-and-cuda-copy-for-i32.patch
@@ -5,17 +5,17 @@ Subject: [PATCH] add argsort and cuda copy for i32
 
 ---
  ggml/src/ggml-cpu/ops.cpp            |  43 ++++++
- ggml/src/ggml-cuda/argsort.cu        | 122 +++++++++++++--
+ ggml/src/ggml-cuda/argsort.cu        | 120 +++++++++++++--
  ggml/src/ggml-cuda/cpy-utils.cuh     |   6 +
  ggml/src/ggml-cuda/cpy.cu            |  40 +++++
  ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++
- 5 files changed, 414 insertions(+), 12 deletions(-)
+ 5 files changed, 413 insertions(+), 11 deletions(-)
 
 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
-index 303278397..7d1733adb 100644
+index b7a70e06f..c90b7db7e 100644
 --- a/ggml/src/ggml-cpu/ops.cpp
 +++ b/ggml/src/ggml-cpu/ops.cpp
-@@ -7932,6 +7932,45 @@ static void ggml_compute_forward_argsort_f32(
+@@ -8022,6 +8022,45 @@ static void ggml_compute_forward_argsort_f32(
      }
  }
  
@@ -61,7 +61,7 @@ index 303278397..7d1733adb 100644
  void ggml_compute_forward_argsort(
      const ggml_compute_params * params,
      ggml_tensor * dst) {
-@@ -7943,6 +7982,10 @@ void ggml_compute_forward_argsort(
+@@ -8033,6 +8072,10 @@ void ggml_compute_forward_argsort(
              {
                  ggml_compute_forward_argsort_f32(params, dst);
              } break;
@@ -73,10 +73,10 @@ index 303278397..7d1733adb 100644
              {
                  GGML_ABORT("fatal error");
 diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
-index da9652c3b..b82be371c 100644
+index 4896669c3..6fae8b808 100644
 --- a/ggml/src/ggml-cuda/argsort.cu
 +++ b/ggml/src/ggml-cuda/argsort.cu
-@@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float *   x,
+@@ -198,13 +198,107 @@ void argsort_f32_i32_cuda_bitonic(const float *   x,
      }
  }
  
@@ -185,28 +185,27 @@ index da9652c3b..b82be371c 100644
      GGML_ASSERT( dst->type == GGML_TYPE_I32);
      GGML_ASSERT(ggml_is_contiguous(src0));
  
-@@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+@@ -213,18 +307,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
  
      enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
  
--#ifdef GGML_CUDA_USE_CUB
++    if (src0->type == GGML_TYPE_I32) {
++        argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
++    } else {
+ #ifdef GGML_CUDA_USE_CUB
 -    const int    ncols_pad      = next_power_of_2(ncols);
 -    const size_t shared_mem     = ncols_pad * sizeof(int);
 -    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
--
++        const int    ncols_pad      = next_power_of_2(ncols);
++        const size_t shared_mem     = ncols_pad * sizeof(int);
++        const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+ 
 -    if (shared_mem > max_shared_mem || ncols > 1024) {
 -        ggml_cuda_pool & pool = ctx.pool();
 -        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
-+    if (src0->type == GGML_TYPE_I32) {
-+        argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
-     } else {
+-    } else {
 -        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
 -    }
-+#ifdef GGML_CUDA_USE_CUB
-+        const int    ncols_pad      = next_power_of_2(ncols);
-+        const size_t shared_mem     = ncols_pad * sizeof(int);
-+        const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
-+
 +        if (shared_mem > max_shared_mem || ncols > 1024) {
 +            ggml_cuda_pool & pool = ctx.pool();
 +            argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
@@ -234,10 +233,10 @@ index 7697c292d..00d773dd3 100644
 +    *dst = *src;
 +}
 diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
-index c4ceb4fc5..0e53ecc39 100644
+index ee84303ef..178e82d76 100644
 --- a/ggml/src/ggml-cuda/cpy.cu
 +++ b/ggml/src/ggml-cuda/cpy.cu
-@@ -352,6 +352,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
+@@ -369,6 +369,43 @@ static void ggml_cpy_f32_iq4_nl_cuda(
          (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
  }
  
@@ -281,7 +280,7 @@ index c4ceb4fc5..0e53ecc39 100644
  void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
      const int64_t ne = ggml_nelements(src0);
      GGML_ASSERT(ne == ggml_nelements(src1));
-@@ -481,6 +518,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
+@@ -495,6 +532,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
              ggml_cpy_scalar_cuda
                  (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
          }
@@ -292,10 +291,10 @@ index c4ceb4fc5..0e53ecc39 100644
          if (can_be_transposed) {
              ggml_cpy_scalar_cuda
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index 51bcbae30..236838e9e 100644
+index 6c349aa0c..628af4bb4 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -4954,8 +4954,77 @@ kernel void kernel_argsort_f32_i32(
+@@ -4736,8 +4736,77 @@ kernel void kernel_argsort_f32_i32(
      }
  }
  
@@ -373,7 +372,7 @@ index 51bcbae30..236838e9e 100644
  
  typedef void (argsort_merge_t)(
          constant   ggml_metal_kargs_argsort_merge & args,
-@@ -5110,8 +5179,154 @@ kernel void kernel_argsort_merge_f32_i32(
+@@ -4892,8 +4961,154 @@ kernel void kernel_argsort_merge_f32_i32(
      }
  }
  
@@ -526,5 +525,5 @@ index 51bcbae30..236838e9e 100644
 +template [[host_name("kernel_argsort_merge_i32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_i32_i32;
 +template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32;
  
- kernel void kernel_leaky_relu_f32(
-         constant     ggml_metal_kargs_leaky_relu & args,
+ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
+ 
diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0013-graph-memory-reporting-on-failure.patch
similarity index 91%
rename from llama/patches/0014-graph-memory-reporting-on-failure.patch
rename to llama/patches/0013-graph-memory-reporting-on-failure.patch
index 0b818ec89e4..5323992a1c3 100644
--- a/llama/patches/0014-graph-memory-reporting-on-failure.patch
+++ b/llama/patches/0013-graph-memory-reporting-on-failure.patch
@@ -23,7 +23,7 @@ index 78aa059dd..7fa8403b3 100644
  // Utils
  // Create a buffer and allocate all the tensors in a ggml_context
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index 4ed5f3577..a7ebe5dcd 100644
+index a9d177864..393c329be 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
 @@ -319,6 +319,7 @@ extern "C" {
@@ -35,10 +35,10 @@ index 4ed5f3577..a7ebe5dcd 100644
      GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
      GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
-index 41419b617..73b39bfea 100644
+index 7f414b231..d32b016af 100644
 --- a/ggml/src/ggml-alloc.c
 +++ b/ggml/src/ggml-alloc.c
-@@ -485,6 +485,7 @@ struct node_alloc {
+@@ -480,6 +480,7 @@ struct node_alloc {
  struct ggml_gallocr {
      ggml_backend_buffer_type_t * bufts; // [n_buffers]
      struct vbuffer ** buffers; // [n_buffers]
@@ -46,7 +46,7 @@ index 41419b617..73b39bfea 100644
      struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
      int n_buffers;
  
-@@ -508,6 +509,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
+@@ -503,6 +504,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
      galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *));
      GGML_ASSERT(galloc->buffers != NULL);
  
@@ -56,7 +56,7 @@ index 41419b617..73b39bfea 100644
      galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
      GGML_ASSERT(galloc->buf_tallocs != NULL);
  
-@@ -575,6 +579,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
+@@ -570,6 +574,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
      ggml_hash_set_free(&galloc->hash_set);
      free(galloc->hash_values);
      free(galloc->bufts);
@@ -64,7 +64,7 @@ index 41419b617..73b39bfea 100644
      free(galloc->buffers);
      free(galloc->buf_tallocs);
      free(galloc->node_allocs);
-@@ -904,6 +909,8 @@ static bool ggml_gallocr_reserve_n_impl(
+@@ -899,6 +904,8 @@ static bool ggml_gallocr_reserve_n_impl(
          }
      }
  
@@ -73,7 +73,7 @@ index 41419b617..73b39bfea 100644
      // reallocate buffers if needed
      for (int i = 0; i < galloc->n_buffers; i++) {
          // if the buffer type is used multiple times, we reuse the same buffer
-@@ -940,15 +947,20 @@ static bool ggml_gallocr_reserve_n_impl(
+@@ -935,15 +942,20 @@ static bool ggml_gallocr_reserve_n_impl(
                  galloc->buffers[i] = NULL;
              } else {
                  galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
@@ -97,7 +97,7 @@ index 41419b617..73b39bfea 100644
  }
  
  void ggml_gallocr_reserve_n_size(
-@@ -1118,6 +1130,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+@@ -1113,6 +1125,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
      return ggml_vbuffer_size(galloc->buffers[buffer_id]);
  }
  
@@ -121,10 +121,10 @@ index 41419b617..73b39bfea 100644
  
  static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index 9f37ca70c..1459d16dd 100644
+index c522fec01..c1f4f63ba 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
-@@ -1859,6 +1859,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
+@@ -1861,6 +1861,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
      return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
  }
  
diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0014-ggml-Export-GPU-UUIDs.patch
similarity index 86%
rename from llama/patches/0015-ggml-Export-GPU-UUIDs.patch
rename to llama/patches/0014-ggml-Export-GPU-UUIDs.patch
index ec0dfdc6138..5f1e63720da 100644
--- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch
+++ b/llama/patches/0014-ggml-Export-GPU-UUIDs.patch
@@ -4,28 +4,29 @@ Date: Sun, 30 Nov 2025 11:05:56 -0800
 Subject: [PATCH] ggml: Export GPU UUIDs
 
 ---
- ggml/include/ggml-backend.h        |  1 +
- ggml/src/ggml-cuda/ggml-cuda.cu    | 67 +++++++++++++++++++++++++++---
+ ggml/include/ggml-backend.h        |  2 +
+ ggml/src/ggml-cuda/ggml-cuda.cu    | 72 +++++++++++++++++++++++++++---
  ggml/src/ggml-metal/ggml-metal.cpp |  1 +
- 3 files changed, 63 insertions(+), 6 deletions(-)
+ 3 files changed, 69 insertions(+), 6 deletions(-)
 
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index a7ebe5dcd..03557bb31 100644
+index 393c329be..99412fe56 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
-@@ -158,6 +158,7 @@ extern "C" {
+@@ -158,6 +158,8 @@ extern "C" {
          const char * description;
          // device free memory in bytes
          size_t memory_free;
++        // device UUID
 +        const char * id;
          // device total memory in bytes
          size_t memory_total;
          // device type
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 6519af435..c9d3a2b03 100644
+index 9625513fa..3f56beab8 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -189,6 +189,51 @@ static int ggml_cuda_parse_id(char devName[]) {
+@@ -192,6 +192,51 @@ static int ggml_cuda_parse_id(char devName[]) {
  }
  #endif // defined(GGML_USE_HIP)
  
@@ -77,7 +78,7 @@ index 6519af435..c9d3a2b03 100644
  static ggml_cuda_device_info ggml_cuda_init() {
      ggml_cuda_device_info info = {};
  
-@@ -255,22 +300,24 @@ static ggml_cuda_device_info ggml_cuda_init() {
+@@ -256,22 +301,29 @@ static ggml_cuda_device_info ggml_cuda_init() {
                  info.devices[id].cc += prop.minor * 0x10;
              }
          }
@@ -102,21 +103,26 @@ index 6519af435..c9d3a2b03 100644
          info.devices[id].cc = 100*prop.major + 10*prop.minor;
 -        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
 -                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
++#ifdef __CUDA_ARCH_LIST__
++        if (std::getenv("GGML_CUDA_INIT") != NULL) {
++            GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch");
++        }
++#endif // defined(__CUDA_ARCH_LIST__)
 +        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
 +                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
 +                        ggml_cuda_parse_uuid(prop, id).c_str());
          std::string device_name(prop.name);
          if (device_name == "NVIDIA GeForce MX450") {
              turing_devices_without_mma.push_back({ id, device_name });
-@@ -4110,6 +4157,7 @@ struct ggml_backend_cuda_device_context {
+@@ -4331,6 +4383,7 @@ struct ggml_backend_cuda_device_context {
      std::string name;
      std::string description;
      std::string pci_bus_id;
 +    std::string id;
+     int op_offload_min_batch_size;
  };
  
- static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
-@@ -4198,6 +4246,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
+@@ -4420,6 +4473,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
  }
  #endif // defined(__linux__)
  
@@ -128,7 +134,7 @@ index 6519af435..c9d3a2b03 100644
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
      ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
      ggml_cuda_set_device(ctx->device);
-@@ -4238,6 +4291,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
+@@ -4460,6 +4518,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
  
      props->name        = ggml_backend_cuda_device_get_name(dev);
      props->description = ggml_backend_cuda_device_get_description(dev);
@@ -136,7 +142,7 @@ index 6519af435..c9d3a2b03 100644
      props->type        = ggml_backend_cuda_device_get_type(dev);
      props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
      ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
-@@ -4834,6 +4888,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
+@@ -5072,6 +5131,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
                  cudaDeviceProp prop;
                  CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
                  dev_ctx->description = prop.name;
@@ -145,10 +151,10 @@ index 6519af435..c9d3a2b03 100644
                  char pci_bus_id[16] = {};
                  snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
 diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
-index f2b7fe692..8fc1c2fb5 100644
+index 9f20b9a25..8d2c93a95 100644
 --- a/ggml/src/ggml-metal/ggml-metal.cpp
 +++ b/ggml/src/ggml-metal/ggml-metal.cpp
-@@ -547,6 +547,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
+@@ -665,6 +665,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
  static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
      props->name        = ggml_backend_metal_device_get_name(dev);
      props->description = ggml_backend_metal_device_get_description(dev);
diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0015-add-C-API-for-mtmd_input_text.patch
similarity index 86%
rename from llama/patches/0016-add-C-API-for-mtmd_input_text.patch
rename to llama/patches/0015-add-C-API-for-mtmd_input_text.patch
index 8205e2cb800..d5962a1f9db 100644
--- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch
+++ b/llama/patches/0015-add-C-API-for-mtmd_input_text.patch
@@ -10,11 +10,11 @@ Signed-off-by: Gabe Goodhart 
  2 files changed, 13 insertions(+)
 
 diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
-index 2638fe4fc..c4e905a4e 100644
+index 8ca979c86..03fcb32e7 100644
 --- a/tools/mtmd/mtmd.cpp
 +++ b/tools/mtmd/mtmd.cpp
-@@ -87,6 +87,16 @@ enum mtmd_slice_tmpl {
-     MTMD_SLICE_TMPL_IDEFICS3,
+@@ -88,6 +88,16 @@ enum mtmd_slice_tmpl {
+     MTMD_SLICE_TMPL_LFM2,
  };
  
 +mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) {
@@ -31,10 +31,10 @@ index 2638fe4fc..c4e905a4e 100644
      return "<__media__>";
  }
 diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
-index 9f7e861e9..72cec1937 100644
+index ef25d32bb..a4a45b299 100644
 --- a/tools/mtmd/mtmd.h
 +++ b/tools/mtmd/mtmd.h
-@@ -80,6 +80,9 @@ typedef struct mtmd_input_chunk  mtmd_input_chunk;
+@@ -83,6 +83,9 @@ typedef struct mtmd_input_chunk  mtmd_input_chunk;
  typedef struct mtmd_input_chunks mtmd_input_chunks;
  typedef struct mtmd_input_text   mtmd_input_text;
  
diff --git a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0016-no-power-throttling-win32-with-gnuc.patch
similarity index 90%
rename from llama/patches/0017-no-power-throttling-win32-with-gnuc.patch
rename to llama/patches/0016-no-power-throttling-win32-with-gnuc.patch
index 010d609e266..f8d5d1c70aa 100644
--- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch
+++ b/llama/patches/0016-no-power-throttling-win32-with-gnuc.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc
  1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index 53891a91f..8d4851312 100644
+index 2aa324852..b4938ea0b 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
-@@ -2479,7 +2479,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
+@@ -2482,7 +2482,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
          // Newer Windows 11 versions aggresively park (offline) CPU cores and often place
          // all our threads onto the first 4 cores which results in terrible performance with
          // n_threads > 4
diff --git a/llama/patches/0017-ggml-Add-batch-size-hint-to-graph_compute.patch b/llama/patches/0017-ggml-Add-batch-size-hint-to-graph_compute.patch
new file mode 100644
index 00000000000..7e439c0932d
--- /dev/null
+++ b/llama/patches/0017-ggml-Add-batch-size-hint-to-graph_compute.patch
@@ -0,0 +1,16764 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Sat, 10 Jan 2026 15:36:34 -0800
+Subject: [PATCH] ggml: Add batch size hint to graph_compute
+
+This adds a batch_size parameter to backend graph_compute functions
+to provide optimization hints for processing.
+---
+ ggml/include/ggml-backend.h               |     5 +-
+ ggml/src/ggml-backend-impl.h              |     4 +-
+ ggml/src/ggml-backend.cpp                 |    19 +-
+ ggml/src/ggml-blas/ggml-blas.cpp          |     3 +-
+ ggml/src/ggml-cann/ggml-cann.cpp          |     4 +-
+ ggml/src/ggml-cpu/ggml-cpu.cpp            |     4 +-
+ ggml/src/ggml-cuda/ggml-cuda.cu           |     4 +-
+ ggml/src/ggml-hexagon/ggml-hexagon.cpp    |     4 +-
+ ggml/src/ggml-metal/ggml-metal.cpp        |     4 +-
+ ggml/src/ggml-opencl/ggml-opencl.cpp      |     4 +-
+ ggml/src/ggml-rpc/ggml-rpc.cpp            |     4 +-
+ ggml/src/ggml-sycl/ggml-sycl.cpp          |     4 +-
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp      |     3 +-
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp.orig | 16394 ++++++++++++++++++++
+ ggml/src/ggml-webgpu/ggml-webgpu.cpp      |     3 +-
+ ggml/src/ggml-zdnn/ggml-zdnn.cpp          |     3 +-
+ ggml/src/ggml-zendnn/ggml-zendnn.cpp      |     4 +-
+ 17 files changed, 16448 insertions(+), 22 deletions(-)
+ create mode 100644 ggml/src/ggml-vulkan/ggml-vulkan.cpp.orig
+
+diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
+index 99412fe56..97f630faa 100644
+--- a/ggml/include/ggml-backend.h
++++ b/ggml/include/ggml-backend.h
+@@ -98,7 +98,7 @@ extern "C" {
+ 
+     GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+     GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+-    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
++    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
+ 
+     // NOTE: will be removed, use device version instead
+     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+@@ -308,6 +308,9 @@ extern "C" {
+     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
+     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
+ 
++    // Provide a hint on the batch size to optimize processing (uses heuristics if unset)
++    GGML_API void                 ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size);
++
+     // Initialize backend buffers from a measure graph
+     GGML_API void                 ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);
+     GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
+diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
+index 59190b7c4..a833756f9 100644
+--- a/ggml/src/ggml-backend-impl.h
++++ b/ggml/src/ggml-backend-impl.h
+@@ -106,8 +106,8 @@ extern "C" {
+         // compute the graph with the plan
+         enum ggml_status          (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ 
+-        // compute graph (always async if supported by the backend)
+-        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
++        // compute graph (always async if supported by the backend). batch_size may be -1 if unknown
++        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
+ 
+         // (optional) event synchronization
+         // record an event on this stream
+diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
+index c1f4f63ba..761bd12df 100644
+--- a/ggml/src/ggml-backend.cpp
++++ b/ggml/src/ggml-backend.cpp
+@@ -355,14 +355,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba
+ }
+ 
+ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+-    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
++    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph, -1);
+     ggml_backend_synchronize(backend);
+     return err;
+ }
+ 
+-enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
++enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
+     GGML_ASSERT(backend);
+-    return backend->iface.graph_compute(backend, cgraph);
++    return backend->iface.graph_compute(backend, cgraph, batch_size);
+ }
+ 
+ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+@@ -729,6 +729,8 @@ struct ggml_backend_sched {
+ 
+     bool op_offload;
+ 
++    int batch_size; // a hint on the batch size to optimize processing, -1 to use heuristics
++
+     int debug;
+ 
+     // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC]
+@@ -827,7 +829,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
+         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
+             // check if a backend with higher prio wants to offload the op
+-            if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
++            if (sched->op_offload && (sched->batch_size < 0 || sched->batch_size >= 32) && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
+                 for (int b = 0; b < src_backend_id; b++) {
+                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
+                         SET_CAUSE(tensor, "1.off");
+@@ -1579,7 +1581,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
+         }
+ 
+         if (!sched->callback_eval) {
+-            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
++            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph, sched->batch_size);
+             if (ec != GGML_STATUS_SUCCESS) {
+                 return ec;
+             }
+@@ -1601,7 +1603,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
+ 
+                 struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+ 
+-                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
++                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv, sched->batch_size);
+                 if (ec != GGML_STATUS_SUCCESS) {
+                     return ec;
+                 }
+@@ -1691,12 +1693,17 @@ ggml_backend_sched_t ggml_backend_sched_new(
+ 
+     sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+     sched->op_offload = op_offload;
++    sched->batch_size = -1;
+ 
+     ggml_backend_sched_reset(sched);
+ 
+     return sched;
+ }
+ 
++void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) {
++    sched->batch_size = batch_size;
++}
++
+ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+     if (sched == NULL) {
+         return;
+diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
+index 2e9ddf224..6a399bdb1 100644
+--- a/ggml/src/ggml-blas/ggml-blas.cpp
++++ b/ggml/src/ggml-blas/ggml-blas.cpp
+@@ -220,7 +220,7 @@ static void ggml_backend_blas_free(ggml_backend_t backend) {
+     delete backend;
+ }
+ 
+-static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+ 
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+@@ -254,6 +254,7 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
+     return GGML_STATUS_SUCCESS;
+ 
+     GGML_UNUSED(backend);
++    GGML_UNUSED(batch_size);
+ }
+ 
+ static struct ggml_backend_i blas_backend_i = {
+diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
+index 90a15f217..02a781c31 100644
+--- a/ggml/src/ggml-cann/ggml-cann.cpp
++++ b/ggml/src/ggml-cann/ggml-cann.cpp
+@@ -2187,8 +2187,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
+  * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
+  *         completes successfully, otherwise an appropriate error status.
+  */
+-static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
++
++    GGML_UNUSED(batch_size);
+     ggml_cann_set_device(cann_ctx->device);
+     g_nz_workspaces[cann_ctx->device].clear();
+ 
+diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
+index ddf1737a3..622cf5d24 100644
+--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
++++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
+@@ -167,7 +167,7 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe
+     GGML_UNUSED(backend);
+ }
+ 
+-static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
+     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+ 
+     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+@@ -188,6 +188,8 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
+     cplan.use_ref             = cpu_ctx->use_ref;
+ 
+     return ggml_graph_compute(cgraph, &cplan);
++
++    GGML_UNUSED(batch_size);
+ }
+ 
+ static const struct ggml_backend_i ggml_backend_cpu_i = {
+diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
+index 3f56beab8..0ab859d3c 100644
+--- a/ggml/src/ggml-cuda/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda/ggml-cuda.cu
+@@ -3970,9 +3970,11 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co
+ }
+ #endif // USE_CUDA_GRAPH
+ 
+-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+ 
++    GGML_UNUSED(batch_size);
++
+     ggml_cuda_set_device(cuda_ctx->device);
+ 
+     bool use_cuda_graph             = false;
+diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+index 7a44443a8..6d46006b5 100644
+--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
++++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+@@ -2489,9 +2489,11 @@ static inline int last_compute_op(ggml_cgraph * graph) {
+     return last;
+ }
+ 
+-static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
++static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph, int batch_size) {
+     auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+ 
++    GGML_UNUSED(batch_size);
++
+     HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);
+ 
+     const int last = last_compute_op(graph);
+diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
+index 8d2c93a95..ba12c7c14 100644
+--- a/ggml/src/ggml-metal/ggml-metal.cpp
++++ b/ggml/src/ggml-metal/ggml-metal.cpp
+@@ -526,9 +526,11 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml
+     return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);
+ }
+ 
+-static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_metal_t ctx = (ggml_metal_t)backend->context;
+ 
++    GGML_UNUSED(batch_size);
++
+     return ggml_metal_graph_compute(ctx, cgraph);
+ }
+ 
+diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
+index d7742fa13..28c251129 100644
+--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
++++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
+@@ -3316,9 +3316,11 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor *
+ static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+ static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
+ 
+-static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+ 
++    GGML_UNUSED(batch_size);
++
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+         ggml_tensor * node = cgraph->nodes[i];
+ 
+diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
+index 281fa1bdb..b5f7adf89 100644
+--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
++++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
+@@ -866,9 +866,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
+     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
+ }
+ 
+-static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+ 
++    GGML_UNUSED(batch_size);
++
+     GGML_ASSERT(cgraph->n_nodes > 0);
+     bool reuse = rpc_ctx->gc.is_cached(cgraph);
+     if (reuse) {
+diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
+index 336172700..9e1e2510e 100644
+--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
++++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
+@@ -4374,9 +4374,11 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
+ }
+ #endif
+ 
+-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
+ 
++    GGML_UNUSED(batch_size);
++
+ #ifdef GGML_SYCL_GRAPH
+     bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
+     if (use_sycl_graph) {
+diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+index c9ca7a986..008b82e9b 100644
+--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+@@ -14063,8 +14063,9 @@ static int32_t find_first_set(uint32_t x) {
+     return ret;
+ }
+ 
+-static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
++    GGML_UNUSED(batch_size);
+     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ 
+     if (vk_instance.debug_utils_support) {
+diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp.orig b/ggml/src/ggml-vulkan/ggml-vulkan.cpp.orig
+new file mode 100644
+index 000000000..a07de51d9
+--- /dev/null
++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp.orig
+@@ -0,0 +1,16394 @@
++#include "ggml-vulkan.h"
++#include <vulkan/vulkan_core.h>
++#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
++#include <chrono>
++#include "ggml-cpu.h"
++#endif
++
++// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
++#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
++// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
++// to avoid conflicts with applications or other libraries who might use it.
++#if VK_HEADER_VERSION >= 301
++namespace vk::detail { class DispatchLoaderDynamic; }
++using vk::detail::DispatchLoaderDynamic;
++#else
++namespace vk { class DispatchLoaderDynamic; }
++using vk::DispatchLoaderDynamic;
++#endif
++DispatchLoaderDynamic & ggml_vk_default_dispatcher();
++#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
++
++#include <vulkan/vulkan.hpp>
++
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++#include 
++
++#if defined(_MSC_VER)
++# define NOMINMAX 1
++# include <windows.h>
++# define YIELD() YieldProcessor()
++#elif defined(__clang__) || defined(__GNUC__)
++# if defined(__x86_64__) ||defined(__i386__)
++#  include <immintrin.h>
++#  define YIELD() _mm_pause()
++# elif defined(__arm__) || defined(__aarch64__)
++#  if defined(__clang__)
++#   include <arm_acle.h>
++#   define YIELD() __yield()
++#  else
++#   define YIELD() asm volatile("yield")
++#  endif
++# endif
++#endif
++
++#if !defined(YIELD)
++#define YIELD()
++#endif
++
++#include "ggml-impl.h"
++#include "ggml-backend-impl.h"
++
++#include "ggml-vulkan-shaders.hpp"
++
++// remove this once it's more widely available in the SDK
++#if !defined(VK_KHR_shader_bfloat16)
++
++#define VK_KHR_shader_bfloat16 1
++#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION                          1
++#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME                        "VK_KHR_shader_bfloat16"
++#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
++#define VK_COMPONENT_TYPE_BFLOAT16_KHR                               ((VkComponentTypeKHR)1000141000)
++
++typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
++    VkStructureType                       sType;
++    void*                                 pNext;
++    VkBool32                              shaderBFloat16Type;
++    VkBool32                              shaderBFloat16DotProduct;
++    VkBool32                              shaderBFloat16CooperativeMatrix;
++} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
++#endif
++
++#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
++#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
++static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
++
++#define VK_VENDOR_ID_AMD 0x1002
++#define VK_VENDOR_ID_APPLE 0x106b
++#define VK_VENDOR_ID_INTEL 0x8086
++#define VK_VENDOR_ID_NVIDIA 0x10de
++#define VK_VENDOR_ID_QUALCOMM 0x5143
++
++#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
++
++#define GGML_VK_MAX_NODES 8192
++
++#define VK_CHECK(err, msg)                                          \
++    do {                                                            \
++        vk::Result err_ = (err);                                    \
++        if (err_ != vk::Result::eSuccess) {                         \
++            fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n",  \
++                #err, to_string(err_).c_str(), __FILE__, __LINE__); \
++            exit(1);                                                \
++        }                                                           \
++    } while (0)
++
++#ifdef GGML_VULKAN_DEBUG
++#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
++#else
++#define VK_LOG_DEBUG(msg) ((void) 0)
++#endif // GGML_VULKAN_DEBUG
++
++struct ggml_backend_vk_context;
++
++#define MAX_PARAMETER_COUNT 12
++// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
++#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
++
++typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
++
++struct vk_pipeline_struct {
++    std::string name;
++    vk::ShaderModule shader_module;
++    vk::PipelineLayout layout;
++    vk::Pipeline pipeline;
++    uint32_t push_constant_size;
++    uint32_t parameter_count;
++    std::array<uint32_t, 3> wg_denoms;
++    uint32_t align;
++    // true if fields have been set by ggml_vk_create_pipeline
++    bool initialized {};
++    // set to true to request the pipeline is compiled
++    std::atomic<bool> needed {};
++    // set to true when the shader has been compiled
++    std::atomic<bool> compiled {};
++    // number of registers used, extracted from pipeline executable properties
++    uint32_t register_count {};
++
++#if defined(VK_EXT_shader_64bit_indexing)
++    bool is_64b_indexing {};
++#endif
++    // linked list of pipelines for multiple compilation variants.
++    // currently only used to compile a 64-bit indexing variant.
++    vk_pipeline next;
++};
++
++typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
++
++static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
++
++struct vk_matmul_pipeline_struct {
++    vk_pipeline l, m, s;
++    vk_pipeline a_l, a_m, a_s;
++    // Returns true when all unaligned pipelines are null.
++    // We only check for unaligned variants since one of the unaligned pipelines must exist
++    // while aligned pipelines are optional
++    bool is_empty() const {
++        return l == nullptr && m == nullptr && s == nullptr;
++    }
++};
++typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
++
++struct vk_matmul_pipeline2 {
++    vk_matmul_pipeline2() {
++        f16acc = std::make_shared<vk_matmul_pipeline_struct>();
++        f32acc = std::make_shared<vk_matmul_pipeline_struct>();
++    }
++    vk_matmul_pipeline f32acc;
++    vk_matmul_pipeline f16acc;
++};
++
++struct vk_device_struct;
++typedef std::shared_ptr<vk_device_struct> vk_device;
++typedef std::weak_ptr<vk_device_struct> vk_device_ref;
++
++struct vk_buffer_struct;
++typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
++typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
++
++struct ggml_backend_vk_buffer_type_context {
++    std::string name;
++    vk_device device;
++};
++
++struct vk_queue;
++
++// Stores command pool/buffers. There's an instance of this
++// for each (context,queue) pair and for each (device,queue) pair.
++struct vk_command_pool {
++    void init(vk_device& device, vk_queue *q_);
++    void destroy(vk::Device& device);
++
++    vk::CommandPool pool;
++    uint32_t cmd_buffer_idx;
++    std::vector<vk::CommandBuffer> cmd_buffers;
++
++    vk_queue *q;
++};
++
++// Prevent simultaneous submissions to the same queue.
++// This could be per vk_queue if we stopped having two vk_queue structures
++// sharing the same vk::Queue.
++static std::mutex queue_mutex;
++
++struct vk_queue {
++    uint32_t queue_family_index;
++    vk::Queue queue;
++
++    vk_command_pool cmd_pool;
++
++    vk::PipelineStageFlags stage_flags;
++
++    bool transfer_only;
++
++    // copy everything except the cmd_pool
++    void copyFrom(vk_queue &other) {
++        queue_family_index = other.queue_family_index;
++        queue = other.queue;
++        stage_flags = other.stage_flags;
++        transfer_only = other.transfer_only;
++    }
++};
++
++static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
++static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
++static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
++static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
++static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
++static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
++    /* .get_name         = */ ggml_backend_vk_buffer_type_name,
++    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,
++    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,
++    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
++    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,
++    /* .is_host          = */ NULL,
++};
++
++class vk_memory_logger;
++class vk_perf_logger;
++static void ggml_vk_destroy_buffer(vk_buffer& buf);
++static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
++
++static constexpr uint32_t mul_mat_vec_max_cols = 8;
++static constexpr uint32_t p021_max_gqa_ratio = 8;
++
++enum vk_device_architecture {
++    OTHER,
++    AMD_GCN,
++    AMD_RDNA1,
++    AMD_RDNA2,
++    AMD_RDNA3,
++    INTEL_XE2,
++    NVIDIA_PRE_TURING,
++    NVIDIA_TURING,
++};
++
++static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
++    vk::PhysicalDeviceProperties props = device.getProperties();
++
++    if (props.vendorID == VK_VENDOR_ID_AMD) {
++        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
++
++        bool amd_shader_core_properties = false;
++        bool integer_dot_product = false;
++        bool subgroup_size_control = false;
++
++        for (const auto& properties : ext_props) {
++            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
++                amd_shader_core_properties = true;
++            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
++                integer_dot_product = true;
++            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
++                subgroup_size_control = true;
++            }
++        }
++
++        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
++            return vk_device_architecture::OTHER;
++        }
++
++        vk::PhysicalDeviceProperties2 props2;
++        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
++        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
++        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
++
++        props2.pNext = &shader_core_props_amd;
++        shader_core_props_amd.pNext = &integer_dot_props;
++        integer_dot_props.pNext = &subgroup_size_control_props;
++
++        device.getProperties2(&props2);
++
++        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
++            return vk_device_architecture::AMD_GCN;
++        }
++        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
++            // RDNA
++            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
++                return vk_device_architecture::AMD_RDNA1;
++            }
++            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
++                return vk_device_architecture::AMD_RDNA3;
++            }
++            return vk_device_architecture::AMD_RDNA2;
++        }
++    } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
++        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
++
++        bool subgroup_size_control = false;
++
++        for (const auto& properties : ext_props) {
++            if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
++                subgroup_size_control = true;
++            }
++        }
++
++        if (!subgroup_size_control) {
++            return vk_device_architecture::OTHER;
++        }
++
++        vk::PhysicalDeviceProperties2 props2;
++        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
++
++        props2.pNext = &subgroup_size_control_props;
++        device.getProperties2(&props2);
++
++        if (subgroup_size_control_props.minSubgroupSize == 16) {
++            // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8.
++            // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value.
++            // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
++            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
++            return vk_device_architecture::INTEL_XE2;
++        }
++    } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
++        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
++
++        bool cooperative_matrix = false;
++        bool sm_builtins = false;
++
++        // Detect "pre-turing" based on lack of coopmat support.
++        for (const auto& properties : ext_props) {
++            if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
++                cooperative_matrix = true;
++            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
++                sm_builtins = true;
++            }
++        }
++
++        if (!cooperative_matrix) {
++            return vk_device_architecture::NVIDIA_PRE_TURING;
++        }
++
++        if (sm_builtins) {
++            vk::PhysicalDeviceProperties2 props2;
++            vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
++
++            props2.pNext = &sm_props;
++
++            device.getProperties2(&props2);
++
++            // Turing has 32, following architectures have 48
++            if (sm_props.shaderWarpsPerSM == 32) {
++                return vk_device_architecture::NVIDIA_TURING;
++            }
++        }
++    }
++    return vk_device_architecture::OTHER;
++}
++
++enum vk_conv_shapes {
++    CONV_SHAPE_128x128,
++    CONV_SHAPE_64x32,
++    CONV_SHAPE_32x256,
++    CONV_SHAPE_COUNT,
++};
++
++struct vk_conv_block_size {
++    uint32_t K;
++    uint32_t NPQ;
++    uint32_t CRS;
++};
++
++vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
++    // K   NPQ  CRS
++    { 128, 128, 16 }, // CONV_SHAPE_128x128
++    {  64,  32, 32 }, // CONV_SHAPE_64x32
++    {  32, 256, 16 }, // CONV_SHAPE_32x256
++};
++
++enum dmmv_wg_sizes {
++    DMMV_WG_SIZE_SUBGROUP,
++    DMMV_WG_SIZE_LARGE,
++    DMMV_WG_SIZE_COUNT,
++};
++
++enum FaCodePath {
++    FA_SCALAR,
++    FA_COOPMAT1,
++    FA_COOPMAT2,
++};
++
++struct vk_fa_pipeline_state {
++    uint32_t HSK, HSV;
++    uint32_t Br, Bc;
++    uint32_t D_split, row_split;
++    bool shmem_staging;
++    FaCodePath path;
++    uint32_t workgroup_size, subgroup_size;
++    bool aligned;
++    bool f32acc;
++    uint32_t flags;
++    uint32_t limit_occupancy_shmem;
++
++    bool operator<(const vk_fa_pipeline_state &b) const {
++        return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
++               std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
++    }
++};
++
++struct vk_conv2d_pipeline_state {
++    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
++        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
++
++    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
++
++    bool operator<(const vk_conv2d_pipeline_state &b) const {
++        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
++               std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
++    }
++};
++
++struct vk_solve_tri_pipeline_state {
++    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
++        : N(N), K(K) {}
++
++    uint32_t N, K;
++
++    bool operator<(const vk_solve_tri_pipeline_state &b) const {
++        return std::tie(N, K) <
++               std::tie(b.N, b.K);
++    }
++};
++
++enum shader_reduction_mode {
++    SHADER_REDUCTION_MODE_SHMEM,
++    SHADER_REDUCTION_MODE_HYBRID,
++    SHADER_REDUCTION_MODE_SUBGROUP,
++    SHADER_REDUCTION_MODE_COUNT,
++};
++
++// argsort pipelines for up to 1<<10 invocations per workgroup
++static constexpr uint32_t num_argsort_pipelines = 11;
++static constexpr uint32_t num_topk_moe_pipelines = 10;
++static constexpr uint32_t num_topk_pipelines = 11;
++
++static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
++                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
++                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
++                                                                             GGML_OP_RESHAPE };
++
++static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY,    GGML_OP_RESHAPE,  GGML_OP_ADD,
++                                                                            GGML_OP_ARGSORT,  GGML_OP_VIEW,     GGML_OP_GET_ROWS,
++                                                                            GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
++                                                                            GGML_OP_DIV,      GGML_OP_RESHAPE };
++
++static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
++                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS };
++
++static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,
++                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
++                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
++
++//node #978 (  SOFT_MAX):     ffn_moe_probs-15 (   0K) [Vulka         ] use=2:    ffn_moe_logits-15 (   0K) [Vulka         ]
++//node #979 (   RESHAPE): ffn_moe_probs-15 (re (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
++//node #980 (   ARGSORT):   ffn_moe_argsort-15 (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
++//node #981 (      VIEW):      ffn_moe_topk-15 (   0K) [Vulka         ] use=4:   ffn_moe_argsort-15 (   0K) [Vulka         ]
++//node #982 (  GET_ROWS):   ffn_moe_weights-15 (   0K) [Vulka         ] use=1: ffn_moe_probs-15 (re (   0K) [Vulka         ]      ffn_moe_topk-15 (   0K) [Vulka         ]
++//node #983 (   RESHAPE): ffn_moe_weights-15 ( (   0K) [Vulka         ] use=2:   ffn_moe_weights-15 (   0K) [Vulka         ]
++//node #984 (  SUM_ROWS): ffn_moe_weights_sum- (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ]
++//node #985 (     CLAMP): ffn_moe_weights_sum_ (   0K) [Vulka         ] use=1: ffn_moe_weights_sum- (   0K) [Vulka         ]
++//node #986 (       DIV): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ] ffn_moe_weights_sum_ (   0K) [Vulka         ]
++//node #987 (   RESHAPE): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights_norm (   0K) [Vulka         ]
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
++    { 1, 0, 0 }, // reshape->src[0]  == softmax
++    { 2, 0, 0 }, // argsort->src[0]  == softmax
++    { 3, 0, 2 }, // view->src[0]     == argsort
++    { 4, 0, 1 }, // get_rows->src[0] == reshape
++    { 4, 1, 3 }, // get_rows->src[1] == view
++    { 5, 0, 4 }, // reshape->src[0]  == get_rows
++    { 6, 0, 5 }, // sum_rows->src[0] == reshape
++    { 7, 0, 6 }, // clamp->src[0]    == sum_rows
++    { 8, 0, 5 }, // div->src[0]      == reshape
++    { 8, 1, 7 }, // div->src[1]      == clamp
++    { 9, 0, 8 }, // reshape->src[0]  == div
++};
++
++//node #436 (     UNARY):     ffn_moe_probs-10 ( 256K) [Vulka         ] use=2:    ffn_moe_logits-10 ( 256K) [Vulka         ]
++//node #437 (   RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ]
++//node #438 (       ADD): ffn_moe_probs_biased ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ] blk.10.exp_probs_b.b (   0K) [Vulka         ]
++//node #439 (   ARGSORT):   ffn_moe_argsort-10 ( 256K) [Vulka         ] use=1: ffn_moe_probs_biased ( 256K) [Vulka         ]
++//node #440 (      VIEW):      ffn_moe_topk-10 ( 255K) [Vulka         ] use=3:   ffn_moe_argsort-10 ( 256K) [Vulka         ]
++//node #441 (  GET_ROWS):   ffn_moe_weights-10 (  12K) [Vulka         ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka         ]      ffn_moe_topk-10 ( 255K) [Vulka         ]
++//node #442 (   RESHAPE): ffn_moe_weights-10 ( (  12K) [Vulka         ] use=2:   ffn_moe_weights-10 (  12K) [Vulka         ]
++//node #443 (  SUM_ROWS): ffn_moe_weights_sum- (   2K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ]
++//node #444 (     CLAMP): ffn_moe_weights_sum_ (   2K) [Vulka         ] use=1: ffn_moe_weights_sum- (   2K) [Vulka         ]
++//node #445 (       DIV): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ] ffn_moe_weights_sum_ (   2K) [Vulka         ]
++//node #446 (   RESHAPE): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights_norm (  12K) [Vulka         ]
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
++    { 1, 0, 0 }, // reshape->src[0]  == sigmoid
++    { 2, 0, 0 }, // add->src[0]      == sigmoid
++    { 3, 0, 2 }, // argsort->src[0]  == add
++    { 4, 0, 3 }, // view->src[0]     == argsort
++    { 5, 0, 1 }, // get_rows->src[0] == reshape
++    { 5, 1, 4 }, // get_rows->src[1] == view
++    { 6, 0, 5 }, // reshape->src[0]  == get_rows
++    { 7, 0, 6 }, // sum_rows->src[0] == reshape
++    { 8, 0, 7 }, // clamp->src[0]    == sum_rows
++    { 9, 0, 6 }, // div->src[0]      == reshape
++    { 9, 1, 8 }, // div->src[1]      == clamp
++    {10, 0, 9 }, // reshape->src[0]  == div
++};
++
++// same as early_softmax_norm but ending after the get_rows
++static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
++    { 1, 0, 0 }, // reshape->src[0]  == softmax
++    { 2, 0, 0 }, // argsort->src[0]  == softmax
++    { 3, 0, 2 }, // view->src[0]     == argsort
++    { 4, 0, 1 }, // get_rows->src[0] == reshape
++    { 4, 1, 3 }, // get_rows->src[1] == view
++};
++
++//node #652 (   ARGSORT):   ffn_moe_argsort-11 (   0K) [Vulka         ] use=1:     ffn_moe_probs-11 (   0K) [Vulka         ]
++//node #653 (      VIEW):      ffn_moe_topk-11 (   0K) [Vulka         ] use=7:   ffn_moe_argsort-11 (   0K) [Vulka         ]
++//node #654 (  GET_ROWS):   ffn_moe_weights-11 (   0K) [Vulka         ] use=1: ffn_moe_probs-11 (re (   0K) [Vulka         ]      ffn_moe_topk-11 (   0K) [Vulka         ]
++//node #655 (   RESHAPE): ffn_moe_weights-11 ( (   0K) [Vulka         ] use=1:   ffn_moe_weights-11 (   0K) [Vulka         ]
++//node #656 (  SOFT_MAX):             node_656 (   0K) [Vulka         ] use=1: ffn_moe_weights-11 ( (   0K) [Vulka         ]
++//node #657 (   RESHAPE): ffn_moe_weights_soft (   0K) [Vulka         ] use=1:             node_656 (   0K) [Vulka         ]
++// Variant where soft_max runs after the top-k gather (node 0 is the argsort,
++// per the node dump above).
++static constexpr std::initializer_list> topk_moe_late_softmax_edges {
++    { 1, 0, 0 }, // view->src[0]     == argsort
++    { 2, 1, 1 }, // get_rows->src[1] == view
++    { 3, 0, 2 }, // reshape->src[0]  == get_rows
++    { 4, 0, 3 }, // soft_max->src[0] == reshape
++    { 5, 0, 4 }, // reshape->src[0]  == soft_max
++};
++
++// Which fused top-k MoE pattern was recognized; TOPK_MOE_COUNT is the number
++// of variants, not a real mode.
++enum topk_moe_mode {
++    TOPK_MOE_EARLY_SOFTMAX,
++    TOPK_MOE_EARLY_SOFTMAX_NORM,
++    TOPK_MOE_LATE_SOFTMAX,
++    TOPK_MOE_SIGMOID_NORM_BIAS,
++    TOPK_MOE_COUNT,
++};
++
++// Edge list for the fused rope+view+set_rows pattern (node 0 is the rope).
++static constexpr std::initializer_list> rope_view_set_rows_edges {
++    { 1, 0, 0 }, // view->src[0]     == rope
++    { 2, 0, 1 }, // set_rows->src[0] == view
++};
++
++// Edge list for the fused rms_norm+mul+rope+view+set_rows pattern (node 0 is
++// the rms_norm).
++static constexpr std::initializer_list> rms_norm_mul_rope_view_set_rows_edges {
++    { 1, 0, 0 }, // mul->src[0]      == rms
++    { 2, 0, 1 }, // rope->src[0]     == mul
++    { 3, 0, 2 }, // view->src[0]     == rope
++    { 4, 0, 3 }, // set_rows->src[0] == view
++};
++
++
++// Per-physical-device state for the Vulkan backend: feature/limit flags probed
++// at init time, queues, and every compiled compute pipeline. Access is guarded
++// by the recursive mutex below.
++// NOTE(review): several container declarations in this copy are missing their
++// template arguments (bare std::map / std::vector / std::unique_ptr / nested
++// std::map, ...) — these look like an extraction artifact; restore the
++// parameter lists from the upstream source before compiling.
++struct vk_device_struct {
++    std::recursive_mutex mutex;
++
++    vk::PhysicalDevice physical_device;
++    vk::PhysicalDeviceProperties properties;
++    std::string name;
++    uint64_t max_memory_allocation_size;
++    uint64_t max_buffer_size;
++    uint64_t suballocation_block_size;
++    uint64_t min_imported_host_pointer_alignment;
++    bool external_memory_host {};
++    bool fp16;
++    bool bf16;
++    bool pipeline_robustness;
++    bool memory_priority;
++    vk::Device device;
++    uint32_t vendor_id;
++    vk::DriverId driver_id;
++    vk_device_architecture architecture;
++    vk_queue compute_queue;
++    vk_queue transfer_queue;
++    bool single_queue;
++    bool support_async;
++    uint32_t subgroup_size;
++    uint32_t subgroup_size_log2;
++    uint32_t shader_core_count;
++    bool uma;
++    bool prefer_host_memory;
++    bool float_controls_rte_fp16;
++    // Per-feature subgroup operation support flags.
++    bool subgroup_basic;
++    bool subgroup_arithmetic;
++    bool subgroup_shuffle;
++    bool subgroup_ballot;
++    bool subgroup_clustered;
++    bool subgroup_vote;
++    bool multi_add;
++    bool shader_int64;
++    bool buffer_device_address;
++    bool vulkan_memory_model;
++
++    bool add_rms_fusion;
++    uint32_t partials_binding_alignment;
++
++    bool shader_64b_indexing;
++
++    bool integer_dot_product;
++    // 0: default, 1: force mmvq, -1: disable mmvq
++    int32_t mmvq_mode;
++
++    bool subgroup_size_control;
++    uint32_t subgroup_min_size;
++    uint32_t subgroup_max_size;
++    bool subgroup_require_full_support;
++
++    // floor(log2(maxComputeWorkGroupInvocations))
++    uint32_t max_workgroup_size_log2 {};
++
++    bool flash_attention_fp16;
++
++    // Cooperative-matrix capabilities and tile shapes (m/n/k).
++    bool coopmat_support;
++    bool coopmat_acc_f32_support {};
++    bool coopmat_acc_f16_support {};
++    bool coopmat_bf16_support {};
++    bool coopmat_support_16x16x16_f16acc {};
++    bool coopmat_support_16x16x16_f32acc {};
++    bool coopmat1_fa_support {};
++    uint32_t coopmat_m;
++    uint32_t coopmat_n;
++    uint32_t coopmat_k;
++
++    bool coopmat_int_support;
++    uint32_t coopmat_int_m;
++    uint32_t coopmat_int_n;
++    uint32_t coopmat_int_k;
++
++    bool coopmat2;
++
++    bool pipeline_executable_properties_support {};
++
++    size_t idx;
++
++    // Which matmul tile sizes (l/m/s) are usable per quantization type.
++    bool mul_mat_l[GGML_TYPE_COUNT];
++    bool mul_mat_m[GGML_TYPE_COUNT];
++    bool mul_mat_s[GGML_TYPE_COUNT];
++    bool mul_mat_id_l[GGML_TYPE_COUNT];
++    bool mul_mat_id_m[GGML_TYPE_COUNT];
++    bool mul_mat_id_s[GGML_TYPE_COUNT];
++
++    vk::DescriptorSetLayout dsl;
++
++    vk_matmul_pipeline pipeline_matmul_f32 {};
++    vk_matmul_pipeline pipeline_matmul_f32_f16 {};
++    vk_matmul_pipeline pipeline_matmul_bf16 {};
++    vk_matmul_pipeline2 pipeline_matmul_f16;
++    vk_matmul_pipeline2 pipeline_matmul_f16_f32;
++
++    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
++    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
++    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
++
++    vk_matmul_pipeline pipeline_matmul_id_f32 {};
++    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
++    vk_matmul_pipeline2 pipeline_matmul_id_f16;
++    vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
++
++    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
++    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
++
++    vk_pipeline pipeline_matmul_split_k_reduce;
++    vk_pipeline pipeline_quantize_q8_1_x4;
++
++    vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
++    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
++    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
++
++    vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
++    vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
++
++    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
++    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
++    vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_acc_f32;
++    vk_pipeline pipeline_set_f32;
++
++    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
++    vk_pipeline pipeline_add[2][2][2];
++    vk_pipeline pipeline_add_norepeat[2][2][2];
++    vk_pipeline pipeline_sub[2][2][2];
++    vk_pipeline pipeline_sub_norepeat[2][2][2];
++    vk_pipeline pipeline_mul[2][2][2];
++    vk_pipeline pipeline_mul_norepeat[2][2][2];
++    vk_pipeline pipeline_div[2][2][2];
++    vk_pipeline pipeline_div_norepeat[2][2][2];
++    vk_pipeline pipeline_add_rms[2][2][2];
++    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
++
++    // indexed by num_additional_fused_ops == num_adds - 1
++    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
++    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
++
++    vk_pipeline pipeline_add_id_f32;
++
++    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
++    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
++    vk_pipeline pipeline_scale_f32;
++    vk_pipeline pipeline_sqr_f32;
++    vk_pipeline pipeline_sqrt_f32;
++    vk_pipeline pipeline_sin_f32;
++    vk_pipeline pipeline_cos_f32;
++    vk_pipeline pipeline_log[2];
++    vk_pipeline pipeline_tri[2];
++    vk_pipeline pipeline_diag[2];
++    vk_pipeline pipeline_clamp_f32;
++    vk_pipeline pipeline_pad_f32;
++    vk_pipeline pipeline_roll_f32;
++    vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
++    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
++    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
++    vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
++    vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
++    vk_pipeline pipeline_norm_f32;
++    vk_pipeline pipeline_group_norm_f32;
++    vk_pipeline pipeline_rms_norm_f32;
++    vk_pipeline pipeline_rms_norm_mul_f32;
++    vk_pipeline pipeline_rms_norm_partials_f32;
++    vk_pipeline pipeline_rms_norm_mul_partials_f32;
++    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
++    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
++    vk_pipeline pipeline_rms_norm_back_f32;
++    vk_pipeline pipeline_l2_norm_f32;
++
++    // [src/dst 0=fp32,1=fp16]
++    vk_pipeline pipeline_exp[2];
++    vk_pipeline pipeline_gelu[2];
++    vk_pipeline pipeline_gelu_erf[2];
++    vk_pipeline pipeline_gelu_quick[2];
++    vk_pipeline pipeline_silu[2];
++    vk_pipeline pipeline_relu[2];
++    vk_pipeline pipeline_xielu[2];
++    vk_pipeline pipeline_neg[2];
++    vk_pipeline pipeline_tanh[2];
++    vk_pipeline pipeline_sigmoid[2];
++    vk_pipeline pipeline_hardsigmoid[2];
++    vk_pipeline pipeline_hardswish[2];
++    vk_pipeline pipeline_abs[2];
++    vk_pipeline pipeline_softplus[2];
++    vk_pipeline pipeline_step[2];
++    vk_pipeline pipeline_round[2];
++    vk_pipeline pipeline_ceil[2];
++    vk_pipeline pipeline_floor[2];
++    vk_pipeline pipeline_trunc[2];
++
++    vk_pipeline pipeline_add1_f16_f16;
++    vk_pipeline pipeline_add1_f16_f32;
++    vk_pipeline pipeline_add1_f32_f32;
++
++    vk_pipeline pipeline_arange_f32;
++
++    vk_pipeline pipeline_fill_f32;
++
++    vk_pipeline pipeline_geglu[2];
++    vk_pipeline pipeline_reglu[2];
++    vk_pipeline pipeline_swiglu[2];
++    vk_pipeline pipeline_swiglu_oai[2];
++    vk_pipeline pipeline_geglu_erf[2];
++    vk_pipeline pipeline_geglu_quick[2];
++
++    vk_pipeline pipeline_leaky_relu_f32;
++    vk_pipeline pipeline_silu_back_f32;
++    vk_pipeline pipeline_diag_mask_inf_f32;
++    vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
++    vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
++    vk_pipeline pipeline_soft_max_back_f32;
++
++    // Multi-pass soft_max variants for rows too large for a single workgroup.
++    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
++    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
++    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
++
++    vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
++    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
++    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
++    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
++    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
++    vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
++    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
++    vk_pipeline pipeline_sum_rows_f32;
++    vk_pipeline pipeline_cumsum_f32;
++    vk_pipeline pipeline_cumsum_small_f32;
++    vk_pipeline pipeline_cumsum_multipass1_f32;
++    vk_pipeline pipeline_cumsum_multipass2_f32;
++    vk_pipeline pipeline_argmax_f32;
++    vk_pipeline pipeline_count_equal_i32;
++    std::map pipeline_solve_tri_f32;
++    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
++    vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
++    vk_pipeline pipeline_timestep_embedding_f32;
++    vk_pipeline pipeline_conv_transpose_1d_f32;
++    vk_pipeline pipeline_pool2d_f32;
++    vk_pipeline pipeline_rwkv_wkv6_f32;
++    vk_pipeline pipeline_rwkv_wkv7_f32;
++    vk_pipeline pipeline_ssm_scan_f32_d128;
++    vk_pipeline pipeline_ssm_scan_f32_d256;
++    vk_pipeline pipeline_ssm_conv_f32;
++    vk_pipeline pipeline_opt_step_adamw_f32;
++    vk_pipeline pipeline_opt_step_sgd_f32;
++    std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT];
++    std::map pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
++    std::map pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
++    std::map pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
++    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
++    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
++
++    std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
++
++    std::map, vk_pipeline> pipeline_fa_mask_opt;
++
++    vk_pipeline pipeline_flash_attn_split_k_reduce;
++    vk_pipeline pipeline_count_experts;
++
++    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
++    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
++
++    // Weak references to every pipeline created on this device; walked by the
++    // destructor so still-live pipelines are destroyed with the device.
++    std::vector all_pipelines;
++
++    std::vector> pinned_memory;
++
++    vk::Fence fence;
++    vk_buffer sync_staging;
++
++    ggml_backend_buffer_type buffer_type;
++
++    // Debug/environment toggles.
++    bool disable_fusion;
++    bool disable_host_visible_vidmem;
++    bool allow_sysmem_fallback;
++    bool disable_graph_optimize;
++
++    std::unique_ptr memory_logger;
++
++    // Tear down all Vulkan objects owned by this device, in dependency order:
++    // fence, staging buffer, command pools, surviving pipelines, descriptor
++    // set layout, then the device itself.
++    ~vk_device_struct() {
++        VK_LOG_DEBUG("destroy device " << name);
++
++        device.destroyFence(fence);
++
++        ggml_vk_destroy_buffer(sync_staging);
++
++        compute_queue.cmd_pool.destroy(device);
++        transfer_queue.cmd_pool.destroy(device);
++
++        for (auto& pipeline : all_pipelines) {
++            if (pipeline.expired()) {
++                continue;
++            }
++
++            vk_pipeline pl = pipeline.lock();
++            ggml_vk_destroy_pipeline(device, pl);
++        }
++        all_pipelines.clear();
++
++        device.destroyDescriptorSetLayout(dsl);
++
++        device.destroy();
++    }
++};
++
++// Create a transient command pool on the queue's family and attach it to this
++// wrapper; command-buffer reuse restarts from index 0.
++void vk_command_pool::init(vk_device& device, vk_queue *q_) {
++    cmd_buffer_idx = 0;
++    q = q_;
++
++    vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
++    pool = device->device.createCommandPool(command_pool_create_info);
++}
++
++// Destroy the pool (freeing its command buffers with it) and drop our cached
++// command-buffer handles.
++void vk_command_pool::destroy(vk::Device& device) {
++    device.destroyCommandPool(pool);
++    pool = nullptr;
++    cmd_buffers.clear();
++}
++
++// A device buffer plus its backing memory allocation. The destructor frees
++// both; size == 0 marks a never-allocated buffer and is skipped.
++// NOTE(review): std::vector template arguments below appear stripped in this
++// copy — restore from upstream.
++struct vk_buffer_struct {
++    vk::Buffer buffer = VK_NULL_HANDLE;
++    vk::DeviceMemory device_memory = VK_NULL_HANDLE;
++    vk::MemoryPropertyFlags memory_property_flags;
++    void * ptr; // host mapping, when one exists; not initialized here
++    size_t size = 0;
++    vk::DeviceAddress bda_addr {}; // buffer device address, when BDA is enabled
++
++    vk_device device;
++
++    ~vk_buffer_struct() {
++        if (size == 0) {
++            return;
++        }
++        VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
++
++        device->device.freeMemory(device_memory);
++        device->device.destroyBuffer(buffer);
++    }
++};
++
++// A byte range within a vk_buffer, convertible directly to a descriptor
++// buffer info for binding.
++struct vk_subbuffer {
++    vk_buffer buffer;
++    uint64_t offset;
++    uint64_t size;
++
++    operator vk::DescriptorBufferInfo() const {
++        return { buffer->buffer, offset, size };
++    }
++};
++
++// vk_event is used for the event-related backend interfaces. It uses 'event' for
++// event_wait and 'fence' for event_synchronize. Polling on an event for
++// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
++// and would lead to validation errors.
++struct vk_event {
++    vk::Event event;
++    vk::Fence fence;
++};
++
++// Timeline-style semaphore handle paired with its wait/signal value.
++struct vk_semaphore {
++    vk::Semaphore s;
++    uint64_t value;
++};
++
++// One command buffer together with the semaphores it waits on and signals.
++struct vk_submission {
++    vk::CommandBuffer buffer;
++    std::vector wait_semaphores;
++    std::vector signal_semaphores;
++};
++
++typedef std::vector vk_sequence;
++
++// Push constants for the tiled matrix-multiply shaders. Strides are in
++// elements; broadcast2/3 handle ggml dim-2/3 broadcasting of src1 over src0.
++struct vk_mat_mat_push_constants {
++    uint32_t M; uint32_t N; uint32_t K;
++    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
++    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
++    uint32_t base_work_group_z; uint32_t num_batches;
++    uint32_t k_split;
++    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
++    uint32_t padded_N;
++};
++
++// Bit flags packed into the fusion_flags members below, selecting which fused
++// bias/scale inputs the mat-vec shaders apply.
++#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
++#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
++#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
++#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
++
++// Push constants for the matrix-vector multiply shaders.
++struct vk_mat_vec_push_constants {
++    uint32_t ncols;
++    uint32_t stride_a;
++    uint32_t stride_b;
++    uint32_t stride_d;
++    uint32_t batch_stride_a;
++    uint32_t batch_stride_b;
++    uint32_t batch_stride_d;
++    uint32_t fusion_flags;
++    uint32_t base_work_group_y;
++    uint32_t ne02;
++    uint32_t ne12;
++    uint32_t broadcast2;
++    uint32_t broadcast3;
++};
++
++// Push constants for the permuted (0,2,1) f16 mat-vec shader.
++struct vk_mat_vec_p021_push_constants {
++    uint32_t ncols_x;
++    uint32_t nrows_x;
++    uint32_t nchannels_x;
++    uint32_t nchannels_y;
++    uint32_t b_offset;
++    uint32_t d_offset;
++    uint32_t fusion_flags;
++};
++
++// Push constants for the non-contiguous f16 mat-vec shader.
++struct vk_mat_vec_nc_push_constants {
++    uint32_t ncols_x;
++    uint32_t nrows_x;
++    uint32_t row_stride_x;
++    uint32_t channel_stride_x;
++    uint32_t channel_stride_y;
++    uint32_t channel_x_divisor;
++    uint32_t ne12;
++    uint32_t b_offset;
++    uint32_t d_offset;
++    uint32_t nb03;
++    uint32_t nb13;
++    uint32_t nb23;
++    uint32_t fusion_flags;
++};
++
++// Push constants for the expert-indexed (mul_mat_id) matrix multiply.
++struct vk_mat_mat_id_push_constants {
++    uint32_t M; uint32_t N; uint32_t K;
++    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
++    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
++    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
++    uint32_t padded_N;
++};
++// Push constants for the expert-indexed (mul_mat_id) mat-vec multiply.
++struct vk_mat_vec_id_push_constants {
++    uint32_t ncols;
++    uint32_t stride_a;
++    uint32_t stride_b;
++    uint32_t stride_d;
++    uint32_t batch_stride_a;
++    uint32_t batch_stride_b;
++    uint32_t batch_stride_d;
++    uint32_t fusion_flags;
++    uint32_t nei0;
++    uint32_t ne11;
++    uint32_t expert_i1;
++    uint32_t nbi1;
++};
++
++// Push constants for the flash-attention shaders. Sized to stay within the
++// 128-byte push-constant minimum guaranteed by Vulkan (see static_assert).
++struct vk_flash_attn_push_constants {
++    uint32_t N;
++    uint32_t KV;
++
++    uint32_t ne1;
++    uint32_t ne2;
++    uint32_t ne3;
++
++    uint32_t neq2;
++    uint32_t neq3;
++    uint32_t nek2;
++    uint32_t nek3;
++    uint32_t nev2;
++    uint32_t nev3;
++    uint32_t nem1;
++    uint32_t nem2;
++    uint32_t nem3;
++
++    uint32_t nb01;
++    uint32_t nb02;
++    uint32_t nb03;
++    uint32_t nb11;
++    uint32_t nb12;
++    uint32_t nb13;
++    uint32_t nb21;
++    uint32_t nb22;
++    uint32_t nb23;
++
++    float scale;
++    float max_bias;
++    float logit_softcap;
++
++    // ALiBi slope parameters.
++    uint32_t mask_n_head_log2;
++    float m0;
++    float m1;
++
++    uint32_t gqa_ratio;
++    uint32_t split_kv;
++    uint32_t k_num;
++};
++static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
++
++// Generic push constants shared by simple elementwise/reduction ops.
++struct vk_op_push_constants {
++    uint32_t KX;
++    uint32_t KY;
++    float param1;
++    float param2;
++    float param3;
++    float param4;
++};
++
++// Push constants for the expert-counting shader.
++struct vk_op_count_experts_push_constants {
++    uint32_t ne00;
++    uint32_t ne01;
++    uint32_t nb00;
++    uint32_t nb01;
++    uint32_t a_offset;
++};
++
++// Push constants for the gated-linear-unit (GLU) shaders.
++struct vk_op_glu_push_constants {
++    uint32_t N;
++    uint32_t ne00;
++    uint32_t ne20;
++    uint32_t mode;  // 0: default, 1: swapped, 2: split
++    float alpha; // for swiglu_oai
++    float limit;
++};
++
++// Push constants for unary ops: src0/dst shapes and element strides, plus
++// precomputed fastdiv magic values (filled by init_pushconst_fastdiv).
++struct vk_op_unary_push_constants {
++    uint32_t ne;
++    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
++    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
++    uint32_t misalign_offsets;
++    float param1; float param2;
++    uint32_t ne0_012mp; uint32_t ne0_012L;
++    uint32_t ne0_01mp;  uint32_t ne0_01L;
++    uint32_t ne0_0mp;   uint32_t ne0_0L;
++    uint32_t ne1_012mp; uint32_t ne1_012L;
++    uint32_t ne1_01mp;  uint32_t ne1_01L;
++    uint32_t ne1_0mp;   uint32_t ne1_0L;
++};
++static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
++
++// Fill a vk_op_unary_push_constants from src0/dst shapes and strides.
++// Strides are converted from bytes to element counts. 'ne' overrides the
++// number of elements to process; when 0, all of dst is processed and src0/dst
++// must then have equal element counts. Buffer offsets and fastdiv magic
++// values are filled in later.
++static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
++    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
++    ne = ne != 0 ? ne : ggml_nelements(dst);
++    GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max());
++
++    vk_op_unary_push_constants p{};
++    p.ne = (uint32_t)ne;
++
++    size_t src0_tsize = ggml_type_size(src0->type);
++    p.ne00 = (uint32_t)src0->ne[0];
++    p.ne01 = (uint32_t)src0->ne[1];
++    p.ne02 = (uint32_t)src0->ne[2];
++    p.ne03 = (uint32_t)src0->ne[3];
++    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
++    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
++    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
++    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
++
++    size_t dst_tsize = ggml_type_size(dst->type);
++    p.ne10 = (uint32_t)dst->ne[0];
++    p.ne11 = (uint32_t)dst->ne[1];
++    p.ne12 = (uint32_t)dst->ne[2];
++    p.ne13 = (uint32_t)dst->ne[3];
++    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
++    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
++    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
++    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
++
++    return p; // offsets are initialized later in ggml_vk_op
++}
++
++// Push constants for the pad shader: src0/dst shapes plus per-dimension
++// left/right padding (lpN/rpN) and a circular-padding flag.
++struct vk_op_pad_push_constants {
++    uint32_t ne;
++    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
++    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
++    uint32_t misalign_offsets;
++    uint32_t circular;
++
++    uint32_t lp0; uint32_t rp0;
++    uint32_t lp1; uint32_t rp1;
++    uint32_t lp2; uint32_t rp2;
++    uint32_t lp3; uint32_t rp3;
++};
++
++// Fill a vk_op_pad_push_constants from src0/dst shapes and the padding
++// parameters stored in dst->op_params (lp0,rp0,...,lp3,rp3 in slots 0-7,
++// circular flag in slot 8). Strides are converted to element counts; fastdiv
++// values and buffer offsets are filled in later.
++static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
++    int64_t ne = ggml_nelements(dst);
++    GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max());
++
++    vk_op_pad_push_constants p{};
++    p.ne = (uint32_t)ne;
++
++    size_t src0_tsize = ggml_type_size(src0->type);
++    p.ne00 = (uint32_t)src0->ne[0];
++    p.ne01 = (uint32_t)src0->ne[1];
++    p.ne02 = (uint32_t)src0->ne[2];
++    p.ne03 = (uint32_t)src0->ne[3];
++    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
++    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
++    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
++    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
++
++    size_t dst_tsize = ggml_type_size(dst->type);
++    p.ne10 = (uint32_t)dst->ne[0];
++    p.ne11 = (uint32_t)dst->ne[1];
++    p.ne12 = (uint32_t)dst->ne[2];
++    p.ne13 = (uint32_t)dst->ne[3];
++    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
++    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
++    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
++    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
++
++    p.lp0 = dst->op_params[0];
++    p.rp0 = dst->op_params[1];
++    p.lp1 = dst->op_params[2];
++    p.rp1 = dst->op_params[3];
++    p.lp2 = dst->op_params[4];
++    p.rp2 = dst->op_params[5];
++    p.lp3 = dst->op_params[6];
++    p.rp3 = dst->op_params[7];
++    p.circular = dst->op_params[8];
++
++    return p; // fastdiv values and offsets are initialized later in ggml_vk_op
++}
++
++// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
++// Precompute mp (m' in the paper) and L such that division
++// can be computed using a multiply (high 32b of 64b result)
++// and a shift:
++//
++// n/d = (mulhi(n, mp) + n) >> L;
++//
++// Preconditions: d != 0 (division below); callers pass tensor dimension
++// products, which are >= 1.
++static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
++{
++    // compute L = ceil(log2(d));
++    L = 0;
++    while (L < 32 && (uint32_t{1} << L) < d) {
++        L++;
++    }
++
++    mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
++}
++
++// Per-push-constant-type hook that fills fastdiv magic values; the primary
++// template is a no-op (and rejects const-qualified instantiations), with
++// specializations below for types that carry fastdiv fields.
++template  void init_pushconst_fastdiv(T &p) {
++    GGML_UNUSED(p);
++    static_assert(!std::is_const::value, "unexpected type");
++}
++
++template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
++    // Compute magic values to divide by these six numbers.
++    init_fastdiv_values(p.ne02*p.ne01*p.ne00,  p.ne0_012mp,    p.ne0_012L);
++    init_fastdiv_values(p.ne01*p.ne00,         p.ne0_01mp,     p.ne0_01L);
++    init_fastdiv_values(p.ne00,                p.ne0_0mp,      p.ne0_0L);
++    init_fastdiv_values(p.ne12*p.ne11*p.ne10,  p.ne1_012mp,    p.ne1_012L);
++    init_fastdiv_values(p.ne11*p.ne10,         p.ne1_01mp,     p.ne1_01L);
++    init_fastdiv_values(p.ne10,                p.ne1_0mp,      p.ne1_0L);
++}
++
++// Push constants for binary ops: shapes/strides for src0, src1 and dst.
++struct vk_op_binary_push_constants {
++    uint32_t ne;
++    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
++    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
++    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
++    uint32_t misalign_offsets;
++    float param1; float param2; int32_t param3;
++};
++
++// Push constants for the fused multi-input add shader.
++struct vk_op_multi_add_push_constants {
++    // shape for dst
++    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
++
++    // strides for srcs+dst
++    uint32_t nb[MAX_PARAMETER_COUNT][4];
++
++    uint32_t rms_partials;
++};
++// update multi_add.comp if this changes
++static_assert(MAX_PARAMETER_COUNT == 12);
++static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
++
++// Push constants for the fused top-k MoE shader (see topk_moe_mode above).
++struct vk_op_topk_moe_push_constants {
++    uint32_t n_rows;
++    uint32_t n_experts_push;
++    uint32_t n_expert_used;
++    float clamp_min;
++    float clamp_max;
++    uint32_t gating_func;
++    uint32_t has_bias;
++    uint32_t with_norm;
++    float output_scale;
++    float output_bias;
++};
++
++// Push constants for the add_id shader (per-expert bias add).
++struct vk_op_add_id_push_constants {
++    uint32_t ne0;
++    uint32_t ne1;
++    uint32_t s01;
++    uint32_t s02;
++    uint32_t s11;
++    uint32_t s21;
++};
++
++// Push constants for diag_mask_inf.
++struct vk_op_diag_mask_push_constants {
++    uint32_t ncols;
++    uint32_t rows_per_channel;
++    int32_t n_past;
++};
++
++// Push constants for the RoPE shaders (all variants); set_rows_stride is used
++// by the fused rope+view+set_rows path.
++struct vk_op_rope_push_constants {
++    uint32_t rope_mode;
++    uint32_t nrows;
++    uint32_t n_dims;
++    float freq_scale;
++    float freq_base;
++    float ext_factor;
++    float attn_factor;
++    float corr_dims[2];
++    float theta_scale;
++    uint32_t has_ff;
++    int32_t sections[4];
++    uint32_t is_imrope;
++    uint32_t is_back;
++    uint32_t set_rows_stride;
++    uint32_t ne00;
++    uint32_t ne01;
++    uint32_t ne02;
++    uint32_t nb01;
++    uint32_t nb02;
++    uint32_t nb03;
++    uint32_t nb11;
++    uint32_t nb12;
++    uint32_t nb13;
++};
++static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
++
++// For fused rms_norm+mul+rope(+view+set_rows)
++struct vk_op_rms_norm_mul_rope_push_constants {
++    vk_op_binary_push_constants bin;
++    vk_op_rope_push_constants rope;
++};
++
++// Push constants for soft_max (m0/m1/n_head_log2 are the ALiBi slope terms).
++struct vk_op_soft_max_push_constants {
++    uint32_t KX;
++    uint32_t KY;
++    uint32_t ne00;
++    uint32_t ne01;
++    uint32_t ne02;
++    uint32_t ne12;
++    uint32_t ne13;
++    uint32_t nb11;
++    uint32_t nb12;
++    uint32_t nb13;
++    float scale;
++    float max_bias;
++    float m0;
++    float m1;
++    uint32_t n_head_log2;
++    uint32_t nrows_x;
++    uint32_t has_sinks;
++};
++
++// Push constants for argsort; the outer/inner ranges drive the multi-dispatch
++// bitonic passes of the "large" variant.
++struct vk_op_argsort_push_constants {
++    uint32_t ncols;
++    uint32_t ncols_padded;
++    uint32_t ncols_padded_log2;
++    uint32_t nrows;
++    uint32_t order;
++    uint32_t outer_start;
++    uint32_t outer_end;
++    uint32_t inner_start;
++    uint32_t inner_end;
++};
++
++// Push constants for the multi-pass top-k shader; first_pass/last_pass flag
++// the reduction stage being dispatched.
++struct vk_op_topk_push_constants {
++    uint32_t orig_ncols;
++    uint32_t ncols_input;
++    uint32_t ncols_output;
++    uint32_t k;
++    uint32_t nrows;
++    uint32_t first_pass;
++    uint32_t last_pass;
++};
++
++// Push constants for im2col (2D).
++struct vk_op_im2col_push_constants {
++    uint64_t dst_addr;
++    uint32_t batch_offset; uint32_t offset_delta;
++    uint32_t IC;
++    uint32_t IW; uint32_t IH;
++    uint32_t OW; uint32_t OH;
++    uint32_t KW; uint32_t KH;
++    uint32_t pelements;
++    uint32_t CHW;
++    int32_t s0; int32_t s1;
++    int32_t p0; int32_t p1;
++    int32_t d0; int32_t d1;
++    uint32_t batch_IC;
++};
++
++// Push constants for im2col (3D); the compound names are precomputed
++// dimension products used for index decomposition in the shader.
++struct vk_op_im2col_3d_push_constants {
++    uint64_t dst_addr;
++    uint32_t nb10;
++    uint32_t nb11;
++    uint32_t nb12;
++    uint32_t nb13;
++    uint32_t s0;
++    uint32_t s1;
++    uint32_t s2;
++    uint32_t p0;
++    uint32_t p1;
++    uint32_t p2;
++    uint32_t d0;
++    uint32_t d1;
++    uint32_t d2;
++    uint32_t IW;
++    uint32_t IH;
++    uint32_t ID;
++    uint32_t IC;
++    uint32_t KW;
++    uint32_t OH;
++    uint32_t KD_KH_KW;
++    uint32_t KH_KW;
++    uint32_t IC_KD_KH_KW;
++    uint32_t N_OD_OH;
++    uint32_t OD_OH;
++    uint32_t OD_OH_OW_IC_KD_KH_KW;
++    uint32_t OH_OW_IC_KD_KH_KW;
++    uint32_t OW_IC_KD_KH_KW;
++    uint32_t misalign_offsets;
++};
++
++// Push constants for timestep_embedding.
++struct vk_op_timestep_embedding_push_constants {
++    uint32_t nb1;
++    uint32_t dim;
++    uint32_t max_period;
++};
++
++// Push constants for conv_transpose_1d.
++struct vk_op_conv_transpose_1d_push_constants {
++    uint32_t Cout;
++    uint32_t Cin;
++    uint32_t K;
++    uint32_t L;
++    uint32_t KL;
++
++    uint32_t nb01;
++    uint32_t nb02;
++    uint32_t nb11;
++    uint32_t nb1;
++
++    int32_t s0;
++};
++
++// Push constants for pool2d (op selects max/avg; k/s/p are kernel, stride,
++// padding per axis).
++struct vk_op_pool2d_push_constants {
++    uint32_t IW; uint32_t IH;
++    uint32_t OW; uint32_t OH;
++    uint32_t OC;
++    uint32_t pelements;
++    uint32_t op;
++    int32_t k0; int32_t k1;
++    int32_t s0; int32_t s1;
++    int32_t p0; int32_t p1;
++};
++
++// Push constants for the RWKV wkv6 kernel.
++struct vk_op_rwkv_wkv6_push_constants {
++    uint32_t B;
++    uint32_t T;
++    uint32_t C;
++    uint32_t H;
++};
++
++// Push constants for the RWKV wkv7 kernel.
++struct vk_op_rwkv_wkv7_push_constants {
++    uint32_t B;
++    uint32_t T;
++    uint32_t C;
++    uint32_t H;
++};
++// Push constants for the SSM (Mamba-style) scan kernel.
++struct vk_op_ssm_scan_push_constants {
++    uint32_t nb02, nb03, nb12, nb13;
++    uint32_t nb21, nb22, nb31;
++    uint32_t nb42, nb43, nb52, nb53;
++    uint32_t s_off;
++    uint32_t n_head, d_head, n_group, n_tok;
++};
++// Push constants for the SSM conv kernel.
++struct vk_op_ssm_conv_push_constants {
++    uint32_t nb01, nb02;
++    uint32_t nb11;
++    uint32_t dst_nb0, dst_nb1, dst_nb2;
++    uint32_t nc, ncs, nr, n_t, n_s;
++};
++
++// Push constants for direct conv2d, including fastdiv values for OW and
++// OW*OH (filled by the init_pushconst_fastdiv specialization below).
++struct vk_op_conv2d_push_constants {
++    uint32_t Cout;
++    uint32_t Cin;
++    uint32_t N;
++
++    uint32_t W;
++    uint32_t H;
++    uint32_t OW;
++    uint32_t OH;
++
++    uint32_t nb01;
++    uint32_t nb02;
++    uint32_t nb03;
++
++    uint32_t nb11;
++    uint32_t nb12;
++    uint32_t nb13;
++
++    uint32_t nb1;
++    uint32_t nb2;
++    uint32_t nb3;
++
++    // init_fastdiv_values constants for dividing by OW, OW*OH
++    uint32_t OWmp;   uint32_t OWL;
++    uint32_t OWOHmp; uint32_t OWOHL;
++};
++
++template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
++    // Compute magic values to divide by OW, OW*OH
++    init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);
++    init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
++}
++
++// Push constants for depthwise conv2d.
++struct vk_op_conv2d_dw_push_constants {
++    uint32_t ne;
++    uint32_t batches;
++    uint32_t channels;
++    uint32_t dst_w;
++    uint32_t dst_h;
++    uint32_t src_w;
++    uint32_t src_h;
++    uint32_t knl_w;
++    uint32_t knl_h;
++    int32_t stride_x;
++    int32_t stride_y;
++    int32_t pad_x;
++    int32_t pad_y;
++    int32_t dilation_x;
++    int32_t dilation_y;
++};
++
++// Push constants for the upscale/interpolate shaders (sf* are per-dimension
++// scale factors).
++struct vk_op_upscale_push_constants {
++    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
++    uint32_t ne00; uint32_t ne01;
++    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
++    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
++    float sf0; float sf1; float sf2; float sf3;
++    float pixel_offset;
++};
++
++// Push constants for sum_rows/mean-style row reductions; weight scales the
++// sum (e.g. 1/n for mean), and the fastdiv fields are filled by the
++// specialization below.
++struct vk_op_sum_rows_push_constants
++{
++    uint32_t n_cols;
++    uint32_t ne01, ne02;
++    uint32_t nb01, nb02, nb03;
++    uint32_t nb11, nb12, nb13;
++    float weight;
++    uint32_t misalign_offsets;
++    uint32_t ne0_12mp, ne0_12L;
++    uint32_t ne0_1mp, ne0_1L;
++};
++
++static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
++    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
++    vk_op_sum_rows_push_constants p = {};
++    p.n_cols = (uint32_t)n_cols;
++    p.ne01 = (uint32_t)src->ne[1];
++    p.ne02 = (uint32_t)src->ne[2];
++    p.nb01 = (uint32_t)src->nb[1] / type_size;
++    p.nb02 = (uint32_t)src->nb[2] / type_size;
++    p.nb03 = (uint32_t)src->nb[3] / type_size;
++    p.nb11 = (uint32_t)dst->nb[1] / type_size;
++    p.nb12 = (uint32_t)dst->nb[2] / type_size;
++    p.nb13 = (uint32_t)dst->nb[3] / type_size;
++    p.weight = 1.0f;
++    return p;
++}
++
++template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
++    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
++    init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);
++}
++
++struct vk_quantize_q8_1_push_constants {
++    uint32_t ne;
++    uint32_t num_blocks;
++};
++
++struct vk_op_flash_attn_split_k_reduce_push_constants {
++    uint32_t D;
++    uint32_t ne1;
++    uint32_t ne2;
++    uint32_t ne3;
++    uint32_t k_num;
++    uint32_t sinks;
++};
++
++struct vk_op_flash_attn_mask_opt_push_constants {
++    uint32_t nem0;
++    uint32_t nem1;
++    uint32_t nem2;
++    uint32_t nbm1;
++    uint32_t nbm2;
++    uint32_t nbm3;
++    uint32_t nbd1;
++    uint32_t nbd2;
++    uint32_t nbd3;
++};
++
++// Allow pre-recording command buffers
++struct vk_staging_memcpy {
++    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
++
++    void * dst;
++    const void * src;
++    size_t n;
++};
++
++struct vk_staging_memset {
++    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
++
++    void * dst;
++    uint32_t val;
++    size_t n;
++};
++
++struct vk_context_struct {
++    vk_submission * s;
++    std::vector seqs;
++
++    int exit_tensor_idx;
++
++    std::vector in_memcpys;
++    std::vector out_memcpys;
++    std::vector memsets;
++
++    vk_command_pool * p {};
++};
++typedef std::shared_ptr vk_context;
++typedef std::weak_ptr vk_context_ref;
++
++struct ggml_vk_garbage_collector {
++    std::vector tl_semaphores;
++    std::vector semaphores;
++    std::vector events;
++    std::vector contexts;
++};
++
++static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
++static void ggml_vk_load_shaders(vk_device& device);
++static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
++
++static bool vk_memory_logger_enabled = false;
++
++#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << "ggml_vulkan memory: " << msg << std::endl; }
++
++static std::string format_size(size_t size) {
++    const size_t kib = 1024;
++    const size_t mib = kib * 1024;
++    const size_t gib = mib * 1024;
++
++    std::ostringstream oss;
++    oss << std::fixed << std::setprecision(2);
++
++    if (size >= gib) {
++        oss << static_cast(size) / gib << " GiB";
++    } else if (size >= mib) {
++        oss << static_cast(size) / mib << " MiB";
++    } else if (size >= kib) {
++        oss << static_cast(size) / kib << " KiB";
++    } else {
++        oss << size << " B";
++    }
++
++    return oss.str();
++}
++
++class vk_memory_logger {
++public:
++    vk_memory_logger(): total_device(0), total_host(0) {}
++    void log_allocation(vk_buffer_ref buf_ref, size_t size);
++    void log_deallocation(vk_buffer_ref buf_ref);
++
++private:
++    std::map allocations; // Track allocations
++    size_t total_device;
++    size_t total_host;
++    static std::mutex log_mutex;
++};
++
++std::mutex vk_memory_logger::log_mutex;
++
++static bool vk_perf_logger_enabled = false;
++static bool vk_perf_logger_concurrent = false;
++static bool vk_enable_sync_logger = false;
++// number of calls between perf logger prints
++static uint32_t vk_perf_logger_frequency = 1;
++static std::string vk_pipeline_stats_filter;
++
++class vk_perf_logger {
++  public:
++    void print_timings(bool force = false) {
++        if (timings.empty()) {
++            return;
++        }
++        print_count++;
++        if ((print_count % vk_perf_logger_frequency) != 0 && !force) {
++            return;
++        }
++        print_count = 0;
++        uint64_t total_all_op_times = 0;
++        std::cerr << "----------------\nVulkan Timings:" << std::endl;
++        for (const auto & t : timings) {
++            uint64_t total_op_times = 0;
++            for (const auto & time : t.second) {
++                total_op_times += time;
++            }
++            std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
++                      << " us = " << (total_op_times / 1000.0) << " us";
++
++            // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
++            auto it = flops.find(t.first);
++            if (it != flops.end() && (it->second).size() == t.second.size()) {
++                uint64_t total_op_flops = 0;
++                for (const auto & elem : it->second) {
++                    total_op_flops += elem;
++                }
++                std::cerr << " ("
++                          << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
++                                 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
++                          << " GFLOPS/s)";
++            }
++
++            total_all_op_times += total_op_times;
++
++            std::cerr << std::endl;
++        }
++
++        if (timings.size() > 0) {
++            std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
++        }
++
++        timings.clear();
++        flops.clear();
++    }
++
++    std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
++        *n_flops = 0;
++        std::string fusion_str;
++        if (fusion_name) {
++            fusion_str = fusion_name + std::string(" ");
++        }
++        if (node->op == GGML_OP_UNARY) {
++            return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
++        }
++        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
++            const uint64_t m     = node->ne[0];
++            const uint64_t n     = node->ne[1];
++            const uint64_t k     = node->src[1]->ne[0];
++            const uint64_t batch = node->ne[2] * node->ne[3];
++            std::string    name  = ggml_op_name(node->op);
++            if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
++                (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
++                name += "_VEC";
++            }
++            name += " ";
++            name += ggml_type_name(node->src[0]->type);
++            name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
++            if (node->op == GGML_OP_MUL_MAT_ID) {
++                name += " n_expert=" + std::to_string(node->src[0]->ne[2]);
++            }
++            if (batch > 1) {
++                name += " batch=" + std::to_string(batch);
++            }
++            name = fusion_str + name;
++            *n_flops = m * n * (k + (k - 1)) * batch;
++            return name;
++        }
++        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
++            std::string   name    = ggml_op_name(node->op);
++            ggml_tensor * knl     = node->src[0];
++            uint64_t      OW      = node->ne[0];
++            uint64_t      OH      = node->ne[1];
++            uint64_t      N       = node->ne[3];
++            uint64_t      Cout    = node->ne[2];
++            uint64_t      KW      = knl->ne[0];
++            uint64_t      KH      = knl->ne[1];
++            uint64_t      Cin     = node->src[1]->ne[2];
++            // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
++            uint64_t      size_M  = Cout;
++            uint64_t      size_K  = Cin * KW * KH;
++            uint64_t      size_N  = N * OW * OH;
++            *n_flops = size_M * size_N * (size_K + (size_K - 1));
++            name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
++                    ", N=N*OW*OH=" + std::to_string(size_N);
++            name = fusion_str + name;
++            return name;
++        }
++        if (node->op == GGML_OP_RMS_NORM) {
++            std::string   name    = ggml_op_name(node->op);
++            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
++            name = fusion_str + name;
++            return name;
++        }
++        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
++            const ggml_tensor * dst = node;
++            const ggml_tensor * q = node->src[0];
++            const ggml_tensor * k = node->src[1];
++            const ggml_tensor * v = node->src[2];
++            const ggml_tensor * m = node->src[3];
++            std::stringstream name;
++            name << fusion_str;
++            name << ggml_op_name(node->op) <<
++                " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
++                " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
++                " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
++                " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
++                " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
++            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
++            return name.str();
++        }
++        if (node->op == GGML_OP_TOP_K) {
++            std::stringstream name;
++            name << fusion_str;
++            name << ggml_op_name(node->op) <<
++                " K=" << node->ne[0] <<
++                " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
++            return name.str();
++        }
++        return fusion_str + ggml_op_name(node->op);
++    }
++
++    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
++        uint64_t n_flops;
++        std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
++        if (n_flops) {
++            flops[name].push_back(n_flops);
++        }
++        timings[name].push_back(time);
++    }
++
++    void log_timing(const std::vector &nodes, const std::vector &names, uint64_t time) {
++        uint64_t total_flops = 0;
++        std::string name;
++        for (size_t n = 0; n < nodes.size(); ++n) {
++            uint64_t n_flops = 0;
++            name += get_node_fusion_name(nodes[n], names[n], &n_flops);
++            total_flops += n_flops;
++
++            if (n != nodes.size() - 1) {
++                name += ", ";
++            }
++        }
++        if (total_flops) {
++            flops[name].push_back(total_flops);
++        }
++        timings[name].push_back(time);
++    }
++
++  private:
++    std::map> timings;
++    std::map> flops;
++    uint32_t print_count {};
++};
++
++struct ggml_backend_vk_context {
++    std::string name;
++
++    vk_device device;
++
++    size_t semaphore_idx, event_idx;
++    ggml_vk_garbage_collector gc;
++    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
++    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials, sync_staging;
++    vk::Fence fence, almost_ready_fence;
++    bool submit_pending {};
++    bool almost_ready_fence_pending {};
++    // Set before op_add and unset after op_rms_norm to indicate that the add should
++    // write partial sums to accumulate the square of the vector components
++    bool do_add_rms_partials_offset_calculation;
++    bool do_add_rms_partials;
++
++    uint64_t last_total_mul_mat_bytes {};
++
++    // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
++    vk_pipeline_struct * prealloc_y_last_pipeline_used {};
++    const ggml_tensor * prealloc_y_last_tensor_used {};
++
++    // Track which nodes have been used since the last sync, and whether they were written to
++    std::vector unsynced_nodes_written;
++    std::vector unsynced_nodes_read;
++    // Track which prealloc buffers have pending reads that need to be synchronized.
++    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
++    // and set to true after the buffer contents are consumed.
++    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
++
++    vk_context_ref compute_ctx;
++
++    std::vector tensor_ctxs;
++
++    std::vector descriptor_pools;
++    std::vector descriptor_sets;
++    uint32_t descriptor_set_idx {};
++    uint32_t pipeline_descriptor_set_requirements {};
++
++    vk_command_pool compute_cmd_pool;
++
++    // number of additional consecutive nodes that are being fused with the
++    // node currently being processed
++    int num_additional_fused_ops {};
++    // Bitmask of which fused ops need to write an intermediate value to memory.
++    // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
++    // If there's no fusion, bit 0 is still set.
++    int fused_ops_write_mask {};
++    topk_moe_mode fused_topk_moe_mode {};
++    bool fused_topk_moe_scale {};
++
++    // for GGML_VK_PERF_LOGGER
++    std::unique_ptr perf_logger;
++    vk::QueryPool query_pool;
++    std::vector query_fusion_names;
++    std::vector query_fusion_node_count;
++    std::vector query_nodes;
++    std::vector query_node_idx;
++    int32_t num_queries {};
++    int32_t query_idx {};
++};
++
++static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
++
++static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
++    if (tensor->view_src) {
++        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
++    }
++    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
++}
++
++static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
++{
++    return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
++}
++
++template  void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    GGML_UNUSED(p);
++    GGML_UNUSED(src0);
++    GGML_UNUSED(src1);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++    GGML_UNUSED(dst);
++    static_assert(!std::is_const::value, "unexpected type");
++    GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
++    GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
++    GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
++    GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
++    GGML_ASSERT(!dst  || get_misalign_bytes(ctx, dst) == 0);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_p021_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.b_offset = b_offset;
++    p.d_offset = d_offset;
++
++    GGML_UNUSED(src0);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_nc_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.b_offset = b_offset;
++    p.d_offset = d_offset;
++
++    GGML_UNUSED(src0);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++struct ggml_backend_vk_buffer_context {
++    vk_device_ref device;
++    vk_buffer dev_buffer;
++    std::string name;
++
++    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
++        device(device),
++        dev_buffer(dev_buffer),
++        name(name) {
++    }
++
++    ~ggml_backend_vk_buffer_context() {
++        ggml_vk_destroy_buffer(dev_buffer);
++    }
++};
++
++void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
++    if (!vk_memory_logger_enabled) {
++        return;
++    }
++    std::lock_guard guard(log_mutex);
++    vk_buffer buf = buf_ref.lock();
++    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
++    const std::string type = device ? "device" : "host";
++    allocations[buf->buffer] = size;
++    total_device += device ? size : 0;
++    total_host += device ? 0 : size;
++    VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
++}
++
++void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
++    if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) {
++        return;
++    }
++
++    std::lock_guard guard(log_mutex);
++    vk_buffer buf = buf_ref.lock();
++    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
++    std::string type = device ? "device" : "host";
++    auto it = allocations.find(buf->buffer);
++    total_device -= device ? it->second : 0;
++    total_host -= device ? 0 : it->second;
++    if (it != allocations.end()) {
++        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
++        allocations.erase(it);
++    } else {
++        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
++    }
++}
++
++struct vk_instance_t {
++    vk::Instance instance;
++
++    bool debug_utils_support = false;  // VK_EXT_debug_utils enabled
++    PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
++    PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
++    PFN_vkQueueEndDebugUtilsLabelEXT   pfn_vkQueueEndDebugUtilsLabelEXT   = {};
++    PFN_vkCmdBeginDebugUtilsLabelEXT   pfn_vkCmdBeginDebugUtilsLabelEXT   = {};
++    PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
++    PFN_vkCmdInsertDebugUtilsLabelEXT  pfn_vkCmdInsertDebugUtilsLabelEXT  = {};
++
++    std::vector device_indices;
++    std::vector   device_supports_membudget;
++    vk_device devices[GGML_VK_MAX_DEVICES];
++};
++
++static bool vk_instance_initialized = false;
++static vk_instance_t vk_instance;
++
++#ifdef GGML_VULKAN_CHECK_RESULTS
++static size_t vk_skip_checks;
++static size_t vk_output_tensor;
++
++static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
++static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
++static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
++#endif
++
++typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
++
++static void ggml_backend_vk_free(ggml_backend_t backend);
++
++static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
++    const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
++                                        VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
++    return range;
++}
++
++// Wait for ctx->fence to be signaled.
++static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
++    // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
++    // during this wait.
++    if (ctx->almost_ready_fence_pending) {
++        VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
++        ctx->device->device.resetFences({ ctx->almost_ready_fence });
++        ctx->almost_ready_fence_pending = false;
++    }
++
++    // Spin (w/pause) waiting for the graph to finish executing.
++    vk::Result result;
++    while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
++        if (result != vk::Result::eNotReady) {
++            fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
++            exit(1);
++        }
++        for (uint32_t i = 0; i < 100; ++i) {
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++            YIELD();
++        }
++    }
++    ctx->device->device.resetFences({ ctx->fence });
++}
++
++// variables to track number of compiles in progress
++static uint32_t compile_count = 0;
++static std::mutex compile_count_mutex;
++static std::condition_variable compile_count_cond;
++
++static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
++                                         uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants,
++                                         bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
++    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
++                 ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
++                 disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
++    GGML_ASSERT(parameter_count > 0);
++    GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
++    GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
++
++    vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data));
++    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
++
++    vk::PushConstantRange pcr(
++        vk::ShaderStageFlagBits::eCompute,
++        0,
++        pipeline->push_constant_size
++    );
++
++    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
++    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
++
++    std::vector specialization_entries(specialization_constants.size());
++
++    for (size_t i = 0; i < specialization_constants.size(); i++) {
++        specialization_entries[i].constantID = i;
++        specialization_entries[i].offset = i * sizeof(uint32_t);
++        specialization_entries[i].size = sizeof(uint32_t);
++    }
++
++    vk::SpecializationInfo specialization_info(
++        specialization_entries.size(),
++        specialization_entries.data(),
++        specialization_constants.size() * sizeof(uint32_t),
++        specialization_constants.data()
++    );
++
++    vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
++
++    if (device->subgroup_require_full_support && require_full_subgroups) {
++        pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
++    }
++
++    vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
++            pipeline_shader_stage_create_flags,
++            vk::ShaderStageFlagBits::eCompute,
++            pipeline->shader_module,
++            entrypoint.c_str(),
++            &specialization_info);
++
++    vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
++    pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
++    if (device->subgroup_size_control && required_subgroup_size > 0) {
++        GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
++        pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
++    }
++
++    vk::ComputePipelineCreateInfo compute_pipeline_create_info(
++        device->pipeline_executable_properties_support ?
++            vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :
++            vk::PipelineCreateFlags{},
++        pipeline_shader_create_info,
++        pipeline->layout);
++
++    vk::PipelineRobustnessCreateInfoEXT rci;
++
++    if (device->pipeline_robustness && disable_robustness) {
++        rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
++        rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
++        compute_pipeline_create_info.setPNext(&rci);
++    }
++
++#if defined(VK_EXT_shader_64bit_indexing)
++    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
++    if (pipeline->is_64b_indexing)
++    {
++        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
++        if (device->pipeline_executable_properties_support) {
++            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
++        }
++        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
++        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
++    }
++#endif
++
++    try {
++        pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
++    } catch (const vk::SystemError& e) {
++        std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
++        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
++        throw e;
++    }
++    pipeline->compiled = true;
++
++    if (vk_instance.debug_utils_support) {
++        vk::DebugUtilsObjectNameInfoEXT duoni;
++        duoni.objectType = vk::ObjectType::ePipeline;
++        duoni.pObjectName = pipeline->name.c_str();
++        duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast(pipeline->pipeline));
++        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni));
++    }
++
++    if (device->pipeline_executable_properties_support) {
++        vk::PipelineExecutableInfoKHR executableInfo;
++        executableInfo.pipeline = pipeline->pipeline;
++
++        auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
++
++        bool print_stats = !vk_pipeline_stats_filter.empty() &&
++                           pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
++        if (print_stats) {
++            std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
++        }
++
++        for (auto & s : statistics) {
++            if (print_stats) {
++                std::cerr << "ggml_vulkan:   " << s.name.data() << ": ";
++                switch (s.format) {
++                    case vk::PipelineExecutableStatisticFormatKHR::eBool32:
++                        std::cerr << (s.value.b32 ? "true" : "false");
++                        break;
++                    case vk::PipelineExecutableStatisticFormatKHR::eInt64:
++                        std::cerr << s.value.i64;
++                        break;
++                    case vk::PipelineExecutableStatisticFormatKHR::eUint64:
++                        std::cerr << s.value.u64;
++                        break;
++                    case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
++                        std::cerr << s.value.f64;
++                        break;
++                }
++                std::cerr << std::endl;
++            }
++            // "Register Count" is reported by NVIDIA drivers.
++            if (strcmp(s.name, "Register Count") == 0) {
++                VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
++                pipeline->register_count = (uint32_t)s.value.u64;
++            }
++        }
++    }
++
++    device->all_pipelines.push_back(pipeline);
++
++    {
++        std::lock_guard guard(compile_count_mutex);
++        assert(compile_count > 0);
++        compile_count--;
++    }
++    compile_count_cond.notify_all();
++}
++
++static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
++    VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
++    device.destroyPipelineLayout(pipeline->layout);
++
++    device.destroyShaderModule(pipeline->shader_module);
++
++    device.destroyPipeline(pipeline->pipeline);
++}
++
++static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
++    VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
++    ctx->pipeline_descriptor_set_requirements += n;
++    if (!pipeline->compiled) {
++        pipeline->needed = true;
++        ggml_vk_load_shaders(ctx->device);
++    }
++    ggml_pipeline_allocate_descriptor_sets(ctx);
++}
++
++static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
++
++    if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
++        // Enough descriptors are available
++        return;
++    }
++
++    vk_device& device = ctx->device;
++
++    // Grow by 50% to avoid frequent allocations
++    uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
++    uint32_t to_alloc = needed - ctx->descriptor_sets.size();
++    uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
++    uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
++
++    while (to_alloc > 0) {
++        const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
++        to_alloc -= alloc_count;
++        pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
++
++        if (pool_idx >= ctx->descriptor_pools.size()) {
++            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
++            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
++            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
++        }
++
++        std::vector layouts(alloc_count);
++        for (uint32_t i = 0; i < alloc_count; i++) {
++            layouts[i] = device->dsl;
++        }
++        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
++        std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
++        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
++
++        pool_idx++;
++    }
++}
++
++static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
++    VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
++
++    if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
++        // Reuse command buffer
++        return p.cmd_buffers[p.cmd_buffer_idx++];
++    }
++
++    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
++        p.pool,
++        vk::CommandBufferLevel::ePrimary,
++        1);
++    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
++    auto buf = cmd_buffers.front();
++
++    p.cmd_buffers.push_back(buf);
++    p.cmd_buffer_idx++;
++
++    return buf;
++}
++
++static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
++    if (ctx->seqs.empty()) {
++        if (fence) {
++            std::lock_guard guard(queue_mutex);
++            ctx->p->q->queue.submit({}, fence);
++        }
++        return;
++    }
++    VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
++
++    std::vector<std::vector<uint64_t>> tl_wait_vals;
++    std::vector<std::vector<uint64_t>> tl_signal_vals;
++    std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
++    std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
++    std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
++    std::vector<vk::SubmitInfo> submit_infos;
++    int idx = -1;
++    std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
++
++    size_t reserve = 0;
++
++    for (const auto& sequence : ctx->seqs) {
++        reserve += sequence.size();
++    }
++
++    // Pre-reserve vectors to prevent reallocation, which invalidates pointers
++    tl_wait_semaphores.reserve(reserve);
++    tl_wait_vals.reserve(reserve);
++    tl_signal_semaphores.reserve(reserve);
++    tl_signal_vals.reserve(reserve);
++    tl_submit_infos.reserve(reserve);
++    submit_infos.reserve(reserve);
++    stage_flags.reserve(reserve);
++
++    for (const auto& sequence : ctx->seqs) {
++        for (const auto& submission : sequence) {
++            stage_flags.push_back({});
++            idx++;
++            tl_wait_vals.push_back({});
++            tl_wait_semaphores.push_back({});
++            tl_signal_vals.push_back({});
++            tl_signal_semaphores.push_back({});
++            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
++                stage_flags[idx].push_back(ctx->p->q->stage_flags);
++                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
++                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
++            }
++            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
++                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
++                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
++            }
++            tl_submit_infos.push_back({
++                (uint32_t) submission.wait_semaphores.size(),
++                tl_wait_vals[idx].data(),
++                (uint32_t) submission.signal_semaphores.size(),
++                tl_signal_vals[idx].data(),
++            });
++            tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
++            tl_submit_infos[idx].pNext = nullptr;
++            vk::SubmitInfo si{
++                (uint32_t) submission.wait_semaphores.size(),
++                tl_wait_semaphores[idx].data(),
++                stage_flags[idx].data(),
++                1,
++                &submission.buffer,
++                (uint32_t) submission.signal_semaphores.size(),
++                tl_signal_semaphores[idx].data(),
++            };
++            si.setPNext(&tl_submit_infos[idx]);
++            submit_infos.push_back(si);
++        }
++    }
++
++    std::lock_guard guard(queue_mutex);
++    ctx->p->q->queue.submit(submit_infos, fence);
++
++    ctx->seqs.clear();
++}
++
++static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
++    VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
++    const uint32_t qfsize = queue_family_props.size();
++
++    // Try with avoid preferences first
++    for (uint32_t i = 0; i < qfsize; i++) {
++        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
++            return i;
++        }
++    }
++
++    // Fall back to only required
++    for (size_t i = 0; i < qfsize; i++) {
++        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
++            return i;
++        }
++    }
++
++    // Fall back to reusing compute queue
++    for (size_t i = 0; i < qfsize; i++) {
++        if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
++            return i;
++        }
++    }
++
++    // Fall back to ignoring min_num_queues
++    for (size_t i = 0; i < qfsize; i++) {
++        if (queue_family_props[i].queueFlags & required) {
++            return i;
++        }
++    }
++
++    // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
++    // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
++    if (compute_index >= 0) {
++        return compute_index;
++    }
++
++    std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
++
++    for(auto &q_family : queue_family_props) {
++        std::cerr << "Queue number: "  + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
++    }
++    abort();
++}
++
++static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
++    VK_LOG_DEBUG("ggml_vk_create_queue()");
++    std::lock_guard guard(device->mutex);
++
++    q.queue_family_index = queue_family_index;
++    q.transfer_only = transfer_only;
++
++    q.cmd_pool.init(device, &q);
++
++    q.queue = device->device.getQueue(queue_family_index, queue_index);
++
++    q.stage_flags = stage_flags;
++}
++
++static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
++    vk_context result = std::make_shared<vk_context_struct>();
++    VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
++    ctx->gc.contexts.emplace_back(result);
++    result->p = &p;
++    return result;
++}
++
++static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
++    vk_context result = std::make_shared<vk_context_struct>();
++    VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
++    result->p = &p;
++    return result;
++}
++
++static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
++    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
++    vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
++    vk::SemaphoreCreateInfo ci{};
++    ci.setPNext(&tci);
++    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
++    ctx->gc.semaphores.push_back({ semaphore, 0 });
++    return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
++}
++
++static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
++    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
++    if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
++        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
++        vk::SemaphoreCreateInfo ci{};
++        ci.setPNext(&tci);
++        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
++        ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
++    }
++    return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
++}
++
++static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
++    if (ctx->event_idx >= ctx->gc.events.size()) {
++        ctx->gc.events.push_back(ctx->device->device.createEvent({}));
++    }
++    return ctx->gc.events[ctx->event_idx++];
++}
++
++static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
++    VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
++
++    // Requires command buffers to be done
++    device->device.resetCommandPool(p.pool);
++    p.cmd_buffer_idx = 0;
++}
++
++static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
++    VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
++
++    // Arbitrary frequency to cleanup/reuse command buffers
++    static constexpr uint32_t cleanup_frequency = 10;
++
++    if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
++        ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
++    }
++    if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
++        ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
++    }
++}
++
++static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
++    std::vector<uint32_t> indices;
++
++    for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
++        vk::MemoryType memory_type = mem_props->memoryTypes[i];
++        if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
++            (flags & memory_type.propertyFlags) == flags &&
++            mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
++            indices.push_back(i);
++        }
++    }
++    return indices;
++}
++
++static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
++                                       void *import_ptr = nullptr) {
++    VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
++    if (size > device->max_buffer_size) {
++        throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
++    }
++
++    vk_buffer buf = std::make_shared<vk_buffer_struct>();
++
++    if (size == 0) {
++        buf->size = 0;
++        return buf;
++    }
++
++    vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
++    vk::MemoryAllocateFlags mem_flags {};
++    if (device->buffer_device_address) {
++        usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
++        mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
++    }
++
++    vk::BufferCreateInfo buffer_create_info{
++        vk::BufferCreateFlags(),
++        size,
++        usage_flags,
++        vk::SharingMode::eExclusive,
++        0,
++        nullptr,
++    };
++
++    vk::ExternalMemoryBufferCreateInfo external_memory_bci;
++    if (import_ptr) {
++        external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
++        buffer_create_info.setPNext(&external_memory_bci);
++    }
++
++    buf->buffer = device->device.createBuffer(buffer_create_info);
++
++    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
++
++    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
++
++    const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
++
++    vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
++
++    if (device->memory_priority) {
++        mem_flags_info.setPNext(&mem_priority_info);
++    }
++
++    if (import_ptr) {
++        vk::MemoryHostPointerPropertiesEXT host_pointer_props;
++        try {
++            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
++        } catch (vk::SystemError& e) {
++            GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
++            device->device.destroyBuffer(buf->buffer);
++            return {};
++        }
++        vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
++
++        uint32_t memory_type_idx;
++        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
++        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
++            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
++                continue;
++            }
++            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
++                continue;
++            }
++
++            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
++            // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
++            if ((memory_type.propertyFlags & property_flags) == property_flags) {
++                property_flags = memory_type.propertyFlags;
++                break;
++            }
++        }
++        if (memory_type_idx == 32) {
++            GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
++            device->device.destroyBuffer(buf->buffer);
++            return {};
++        }
++
++        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
++        try {
++            vk::ImportMemoryHostPointerInfoEXT import_info;
++            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
++            import_info.pHostPointer = import_ptr;
++            import_info.setPNext(&mem_flags_info);
++            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
++        } catch (const vk::SystemError& e) {
++        }
++    } else {
++        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
++            const auto & req_flags = *it;
++
++            const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
++
++            if (memory_type_indices.empty()) {
++                continue;
++            }
++            buf->memory_property_flags = req_flags;
++
++            bool done = false;
++
++            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
++                try {
++                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
++                    done = true;
++                    break;
++                } catch (const vk::SystemError& e) {
++                    // loop and retry
++                    // during last attempt throw the exception
++                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
++                        device->device.destroyBuffer(buf->buffer);
++                        throw e;
++                    }
++                }
++            }
++
++            if (done) {
++                break;
++            }
++        }
++    }
++
++    if (!buf->device_memory) {
++        device->device.destroyBuffer(buf->buffer);
++        throw vk::OutOfDeviceMemoryError("No suitable memory type found");
++    }
++
++    buf->ptr = nullptr;
++
++    if (import_ptr) {
++        buf->ptr = import_ptr;
++    } else {
++        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
++            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
++        }
++    }
++
++    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
++
++    buf->device = device;
++    buf->size = size;
++
++    if (device->buffer_device_address) {
++        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
++        buf->bda_addr = device->device.getBufferAddress(addressInfo);
++    }
++
++    device->memory_logger->log_allocation(buf, size);
++
++    return buf;
++}
++
++static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
++    try {
++        return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
++    } catch (const vk::SystemError& e) {
++        std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
++        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
++        throw e;
++    }
++}
++
++static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
++    vk_buffer buf;
++    try {
++        if (device->prefer_host_memory) {
++            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
++                                                       vk::MemoryPropertyFlagBits::eDeviceLocal});
++        } else if (device->uma) {
++            // Fall back to host memory type
++            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
++                                                       vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
++        } else if (device->disable_host_visible_vidmem) {
++            if (device->allow_sysmem_fallback) {
++                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
++                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
++            } else {
++                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++            }
++        } else {
++            // use rebar if available, otherwise fallback to device only visible memory
++            if (device->allow_sysmem_fallback) {
++                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
++                                                           vk::MemoryPropertyFlagBits::eDeviceLocal,
++                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
++            } else {
++                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
++                                                           vk::MemoryPropertyFlagBits::eDeviceLocal});
++            }
++        }
++    } catch (const vk::SystemError& e) {
++        std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
++        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
++        throw e;
++    }
++
++    return buf;
++}
++
++static void ggml_vk_destroy_buffer(vk_buffer& buf) {
++    if (buf == nullptr) {
++        return;
++    }
++
++    if (buf->device != nullptr) {
++        buf->device->memory_logger->log_deallocation(buf);
++    }
++
++    buf.reset();
++}
++
++static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
++    return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
++}
++
++static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
++    VK_LOG_DEBUG("ggml_vk_sync_buffers()");
++
++    const bool transfer_queue = subctx->p->q->transfer_only;
++
++    if (ctx) {
++        ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
++    }
++
++    subctx->s->buffer.pipelineBarrier(
++        subctx->p->q->stage_flags,
++        subctx->p->q->stage_flags,
++        {},
++        { {
++          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
++          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
++        } },
++        {},
++        {}
++    );
++}
++
++static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
++    VK_LOG_DEBUG("ggml_vk_set_event()");
++
++    ctx->s->buffer.setEvent(
++        event,
++        ctx->p->q->stage_flags
++    );
++}
++
++static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
++    VK_LOG_DEBUG("ggml_vk_wait_events()");
++    if (events.empty()) {
++        return;
++    }
++
++    ctx->s->buffer.waitEvents(
++        events,
++        ctx->p->q->stage_flags,
++        ctx->p->q->stage_flags,
++        {},
++        {},
++        {}
++    );
++}
++
++struct vk_fa_tuning_params {
++    FaCodePath path;
++    uint32_t workgroup_size;
++    uint32_t subgroup_size;
++    uint32_t block_rows;
++    uint32_t block_cols;
++    uint32_t d_split;
++    uint32_t row_split;
++    bool shmem_staging;
++    bool disable_subgroups;
++    uint32_t limit_occupancy_shmem;
++
++    void print() const {
++        std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
++                     " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
++                     " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
++                     " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
++    }
++};
++
++static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
++static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
++
++static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
++    GGML_UNUSED(kv_type);
++
++    vk_fa_tuning_params result{};
++    result.path = FA_SCALAR;
++
++    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
++        // Disable subgroup use due to performance issues when enforcing subgroup sizes
++        result.subgroup_size = 32;
++        result.disable_subgroups = true;
++    } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
++        result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
++    } else {
++        result.subgroup_size = device->subgroup_size;
++    }
++
++    // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
++    uint32_t row_split_max_hsk = 64;
++    if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
++        row_split_max_hsk = n_rows <= 8 ? 64 : 128;
++    }
++    result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
++
++    if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
++        result.workgroup_size = result.subgroup_size * 2;
++    } else {
++        result.workgroup_size = result.subgroup_size * 4;
++    }
++
++    const uint32_t D = hsk | hsv;
++
++    const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
++
++    if (n_rows == 1) {
++        result.block_rows = 1;
++        result.block_cols = 64;
++    } else {
++        // row_split 1 means higher register use per row, so block size has to be adjusted
++        if (result.row_split == 1) {
++            result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
++        } else {
++            result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
++        }
++
++        result.block_cols = (D & 8) ? 64 : 32;
++    }
++
++    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
++
++    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
++
++    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
++
++    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
++        result.block_rows /= 2;
++    }
++
++    // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
++    // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
++    // This targets an occupancy of 4 subgroups per SIMD.
++    if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
++        if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
++            // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
++            // Values are guessed, tested on RDNA2
++            result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
++        } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
++            // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
++            // Here low-batch FA with large head size is affected.
++            // n_rows < 4 switch because workgroup size switches from 128 to 256 there.
++            result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
++        }
++    }
++
++    return result;
++}
++
++static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
++    GGML_UNUSED(n_rows);
++    GGML_UNUSED(n_kv);
++    GGML_UNUSED(kv_type);
++    GGML_UNUSED(f32acc);
++
++    vk_fa_tuning_params result{};
++    result.path = FA_COOPMAT1;
++
++    const uint32_t D = hsk | hsv;
++
++    const uint32_t coopmat_block_rows = 16;
++    const uint32_t coopmat_block_cols = 16;
++
++    const uint32_t num_subgroups = 4;
++
++    result.block_rows = coopmat_block_rows;
++    result.block_cols = coopmat_block_cols * num_subgroups;
++    result.row_split = num_subgroups;
++    result.subgroup_size = device->subgroup_size;
++    result.workgroup_size = num_subgroups * result.subgroup_size;
++
++    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
++    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
++
++    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
++
++    return result;
++}
++
++static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
++    GGML_UNUSED(n_kv);
++    GGML_UNUSED(f32acc);
++
++    vk_fa_tuning_params result{};
++    result.path = FA_COOPMAT2;
++
++    const uint32_t D = hsk | hsv;
++
++    const bool small_rows = n_rows < 32;
++
++    if (small_rows) {
++        result.block_rows = 32;
++        result.block_cols = 32;
++    } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
++        result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
++        result.block_cols = 32;
++    } else {
++        result.block_rows = 64;
++        result.block_cols = 64;
++    }
++
++    result.subgroup_size = device->subgroup_size;
++    result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
++
++    return result;
++}
++
++static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
++    FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
++                      device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
++
++    if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
++        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
++        path = FA_SCALAR;
++    }
++
++    if (path == FA_COOPMAT1) {
++        bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
++                        (!f32acc && device->coopmat_support_16x16x16_f16acc);
++        const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
++        bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
++
++        if (!shape_ok || !shmem_ok) {
++            path = FA_SCALAR;
++        }
++    }
++
++    // scalar is faster than coopmat when N==1
++    if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
++        path = FA_SCALAR;
++    }
++
++    switch (path) {
++    case FA_SCALAR:
++        return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
++    case FA_COOPMAT1:
++        return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
++    case FA_COOPMAT2:
++        return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
++    default:
++        throw std::runtime_error("unsupported FaCodePath");
++    }
++}
++
++static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
++                                                  bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
++    uint32_t flags = (use_mask_opt      ? 1 : 0) |
++                     (use_mask          ? 2 : 0) |
++                     (use_logit_softcap ? 4 : 0);
++
++    const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
++
++    return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
++}
++
++static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
++    return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
++            state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
++}
++
++static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
++
++    uint32_t lut_size = 0;
++    switch (src0_type) {
++    case GGML_TYPE_IQ1_S:
++    case GGML_TYPE_IQ1_M:
++        lut_size = 2*2048 + 4*2048;
++        break;
++    case GGML_TYPE_IQ2_XXS:
++        lut_size = 8*256;
++        break;
++    case GGML_TYPE_IQ2_XS:
++        lut_size = 8*512;
++        break;
++    case GGML_TYPE_IQ2_S:
++        lut_size = 8*1024;
++        break;
++    case GGML_TYPE_IQ3_XXS:
++        lut_size = 4*256;
++        break;
++    case GGML_TYPE_IQ3_S:
++        lut_size = 4*512;
++        break;
++    case GGML_TYPE_IQ4_NL:
++    case GGML_TYPE_IQ4_XS:
++    case GGML_TYPE_MXFP4:
++        lut_size = 4*16;
++        break;
++    default:
++        break;
++    }
++
++    // Needs to be kept up to date on shader changes
++    const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
++    const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
++    const uint32_t warps = warptile[0] / warptile[10];
++
++    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
++    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
++    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
++    const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
++
++    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
++    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
++
++    VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
++                 "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
++
++    return supported;
++}
++
++// Per-GPU-architecture pipeline tuning: maps pipeline names to required
++// subgroup sizes, with a fallback default for pipelines not listed.
++struct GpuPipelineConfig {
++    // GPU architecture identifier.
++    // Example: vk_device_architecture::AMD_GCN
++    vk_device_architecture arch;
++
++    // Mapping of pipeline names to their specific subgroup sizes.
++    // Example: {"soft_max_f32", 64}
++    std::unordered_map pipelines;
++
++    // Default subgroup size for this GPU.
++    // Defaults to 0 if not explicitly provided.
++    uint32_t default_subgroup_size = 0;
++};
++
++// Pipeline configuration for RDNA1 GPUs.
++// Keys are matched first exactly, then by substring (see get_subgroup_size).
++static const std::unordered_map rdna1_pipelines = {
++    {"soft_max", 64}, {"im2col", 64},
++    {"argmax", 64}, {"mul_mat_vec", 64},
++    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
++};
++
++// Pipeline configuration for RDNA2 GPUs.
++// Smaller override set than RDNA1; everything else falls back to the default.
++static const std::unordered_map rdna2_pipelines = {
++    {"soft_max", 64}, {"im2col", 64},
++};
++
++// Fallback subgroup size for RDNA GPUs when a pipeline has no explicit override.
++static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
++
++// Define configurations for different GPUs.
++// Consulted by get_subgroup_size(); architectures absent from this table
++// get no required subgroup size (0).
++static std::vector gpu_pipeline_configs = {
++    {
++        vk_device_architecture::AMD_RDNA1,
++        {
++            rdna1_pipelines,
++        },
++        RDNA_DEFAULT_SUBGROUP_SIZE
++    },
++    {
++        vk_device_architecture::AMD_RDNA2,
++        {
++            rdna2_pipelines,
++        },
++        RDNA_DEFAULT_SUBGROUP_SIZE
++    },
++};
++
++// Look up the required subgroup size for a pipeline on a given architecture.
++// Resolution order: exact name match, then longest-substring match, then the
++// architecture's default. Returns 0 when the architecture has no configuration.
++static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
++    for (const auto &config : gpu_pipeline_configs) {
++        if (config.arch == arch) {
++            // Exact match takes precedence over any substring match.
++            auto pipIt = config.pipelines.find(pipeline_name);
++            if (pipIt != config.pipelines.end()) {
++                return pipIt->second;
++            }
++            // Sort candidate keys by descending length so the most specific
++            // (longest) substring match wins, e.g. "mul_mat_vec_f32_f16"
++            // before "mul_mat_vec".
++            std::vector> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
++            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
++                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
++            for (const auto &entry : sorted_pipelines) {
++                if (pipeline_name.find(entry.first) != std::string::npos) {
++                    return entry.second;
++                }
++            }
++            return config.default_subgroup_size;
++        }
++    }
++    return 0; // If no matching configuration is found
++}
++
++static void ggml_vk_load_shaders(vk_device& device) {
++    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
++
++    std::lock_guard guard(device->mutex);
++    // some shaders have a minimum subgroup size
++    const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
++    const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
++    const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
++
++    const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
++    const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
++    const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
++    const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
++
++    const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
++                                      (device->subgroup_size_control && device->subgroup_max_size >= 16);
++
++    // mulmat
++    std::vector l_warptile, m_warptile, s_warptile,
++                          l_warptile_id, m_warptile_id, s_warptile_id,
++                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
++                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
++                          l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
++                          l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
++                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
++                          l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
++                          l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
++    std::array l_wg_denoms, m_wg_denoms, s_wg_denoms,
++                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
++                            l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
++                            l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
++
++    uint32_t l_align, m_align, s_align;
++    if (device->coopmat2) {
++        // spec constants and tile sizes for non-quant matmul/matmul_id
++        l_warptile = { 256, 128, 256, 64, 1 };
++        m_warptile = { 256, 128, 128, 64, 0 };
++        s_warptile = { 128,  64,  64, 64, 0 };
++        l_wg_denoms = {128, 256, 1 };
++        m_wg_denoms = {128, 128, 1 };
++        s_wg_denoms = { 64,  64, 1 };
++
++        // spec constants and tile sizes for quant matmul (non-Qi_K)
++        l_warptile_mmq = { 256, 128, 256, 64, 1 };
++        m_warptile_mmq = { 256, 128, 128, 64, 1 };
++        s_warptile_mmq = { 256, 32,  64, 128, 0 };
++        l_mmq_wg_denoms = { 128, 256, 1 };
++        m_mmq_wg_denoms = { 128, 128, 1 };
++        s_mmq_wg_denoms = { 32,  64,  1 };
++
++        // spec constants and tile sizes for quant matmul (Qi_K)
++        l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
++        m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
++        s_warptile_mmq_k = { 256, 32,  64, 128, 0 };
++        l_mmq_wg_denoms_k = { 128, 256, 1 };
++        m_mmq_wg_denoms_k = { 128, 128, 1 };
++        s_mmq_wg_denoms_k = { 32,  64,  1 };
++
++        // spec constants and tile sizes for quant matmul_id
++        l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
++        m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
++        s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
++        l_mmqid_wg_denoms = { 128, 128, 1 };
++        m_mmqid_wg_denoms = { 128, 64, 1 };
++        s_mmqid_wg_denoms = { 128, 64, 1 };
++
++        l_align = 128;
++        m_align =  64;
++        s_align =  32;
++    } else {
++        // Matrix cores require different warp group sizes
++        const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
++        const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
++        const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
++        const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
++        const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
++        const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
++        const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
++        const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
++        const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
++
++        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
++
++        l_warptile = { 128,             128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
++        m_warptile = { 128,              64,  64, 16, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++        s_warptile = { subgroup_size_32, 32,  32, 16, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
++
++        l_warptile_mmq = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
++        m_warptile_mmq = { 128,              64,  64, 32, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++        s_warptile_mmq = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
++
++        // Integer MMQ has a smaller shared memory profile, but heavier register use
++        l_warptile_mmq_int = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
++        m_warptile_mmq_int = { 128,              64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
++        s_warptile_mmq_int = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, 2, 1, 1, subgroup_size_8 };
++
++        // K-quants use even more registers, mitigate by setting WMITER to 1
++        l_warptile_mmq_int_k = { 128,               128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
++        m_warptile_mmq_int_k = { 128,                64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
++        s_warptile_mmq_int_k = { subgroup_size_32,   32,  32, 32, s_warptile_wm,       32, 1, 2, 1, 1, subgroup_size_8 };
++
++        l_warptile_id = { 128,                      128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
++        m_warptile_id = { 128,                       64,  64, 16, mul_mat_subgroup_size_16,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
++        s_warptile_id = { mul_mat_subgroup_size_16,  32,  32, 16, s_warptile_wm,                32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
++
++        l_warptile_mmqid = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
++        m_warptile_mmqid = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
++        s_warptile_mmqid = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
++
++        l_warptile_mmqid_int = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
++        m_warptile_mmqid_int = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
++        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
++
++        l_warptile_mmqid_int_k = { 128,                     128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
++        m_warptile_mmqid_int_k = { 128,                      64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
++        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,  32, 32, s_warptile_wm,                32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
++
++        // chip specific tuning
++        if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
++            m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
++            m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
++        } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
++            // This is intentionally using tx_m values, slight performance increase
++            l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++            l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++            l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
++        } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
++            // Xe2/Xe3 with coopmat enabled - warptile performance tuning
++            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++            l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
++        }
++
++        l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
++        m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
++        s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
++        l_align = 128;
++        m_align =  64;
++        s_align =  32;
++
++        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
++            ggml_type t = (ggml_type)i;
++            // Disable medium and large matrix multiplication if not enough shared memory is available
++            // Check mmq warptiles as the largest configuration
++            // Throw an error if not enough for any matrix multiplication is available
++            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
++                std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
++                throw std::runtime_error("Shared memory size too small for matrix multiplication.");
++            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
++                device->mul_mat_m[i] = false;
++                device->mul_mat_l[i] = false;
++            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
++                device->mul_mat_l[i] = false;
++            }
++
++            // Disable mul_mat_id if not enough shared memory is available
++            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
++                device->mul_mat_id_s[i] = false;
++                device->mul_mat_id_m[i] = false;
++                device->mul_mat_id_l[i] = false;
++            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
++                device->mul_mat_id_m[i] = false;
++                device->mul_mat_id_l[i] = false;
++            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
++                device->mul_mat_id_l[i] = false;
++            }
++        }
++    }
++
++    if (!device->pipeline_matmul_f32) {
++        device->pipeline_matmul_f32 = std::make_shared();
++    }
++    if (!device->pipeline_matmul_f32_f16) {
++        device->pipeline_matmul_f32_f16 = std::make_shared();
++    }
++    if (!device->pipeline_matmul_id_f32) {
++        device->pipeline_matmul_id_f32 = std::make_shared();
++    }
++    if (!device->pipeline_matmul_bf16) {
++        device->pipeline_matmul_bf16 = std::make_shared();
++    }
++    if (!device->pipeline_matmul_id_bf16) {
++        device->pipeline_matmul_id_bf16 = std::make_shared();
++    }
++
++    std::vector> compiles;
++    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
++                                              uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants,
++                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
++
++        if (!require_full_subgroups && required_subgroup_size == 0) {
++            required_subgroup_size = get_subgroup_size(name, device->architecture);
++        }
++
++        vk_pipeline *ptr = &base_pipeline;
++
++        int num_pipelines = 1;
++#if defined(VK_EXT_shader_64bit_indexing)
++        if (device->shader_64b_indexing) {
++            num_pipelines = 2;
++        }
++#endif
++        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
++            vk_pipeline &pipeline = *ptr;
++            if (!pipeline) {
++                pipeline = std::make_shared();
++            }
++            if (!pipeline->initialized) {
++                pipeline->name = name;
++                pipeline->parameter_count = parameter_count;
++                pipeline->push_constant_size = push_constant_size;
++                pipeline->wg_denoms = wg_denoms;
++                pipeline->align = align;
++                pipeline->initialized = true;
++#if defined(VK_EXT_shader_64bit_indexing)
++                pipeline->is_64b_indexing = (i == 1);
++#endif
++            }
++
++            if (!pipeline->needed || pipeline->compiled) {
++                continue;
++            }
++            // TODO: We're no longer benefitting from the async compiles (shaders are
++            // compiled individually, as needed) and this complexity can be removed.
++            {
++                // wait until fewer than N compiles are in progress
++                uint32_t N = std::max(1u, std::thread::hardware_concurrency());
++                std::unique_lock guard(compile_count_mutex);
++                while (compile_count >= N) {
++                    compile_count_cond.wait(guard);
++                }
++                compile_count++;
++            }
++
++            compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
++                                          parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
++        }
++    };
++
++    auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
++                                              uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants,
++                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
++        return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
++                                       parameter_count, push_constant_size, wg_denoms, specialization_constants,
++                                       align, disable_robustness, require_full_subgroups, required_subgroup_size);
++    };
++
++#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
++        for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
++            FaCodePath path = fa.first.path; \
++            uint32_t Br = fa.first.Br; \
++            uint32_t Bc = fa.first.Bc; \
++            bool aligned = fa.first.aligned; \
++            bool f32acc = fa.first.f32acc; \
++            uint32_t fa_sgs = fa.first.subgroup_size; \
++            bool fa_ds = fa.first.subgroup_size == 0; \
++            if (path == FAPATH) { \
++                if (aligned) { \
++                    if (f32acc) { \
++                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
++                    } else { \
++                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
++                    } \
++                } else { \
++                    if (f32acc) { \
++                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
++                    } else { \
++                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
++                    } \
++                } \
++            } \
++        }
++
++    if (device->flash_attention_fp16) {
++        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
++        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
++        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
++        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
++    } else {
++        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
++        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
++        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
++        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
++    }
++#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++    if (device->coopmat1_fa_support) {
++        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
++        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
++        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
++        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
++    }
++#endif
++#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++    if (device->coopmat2) {
++        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
++        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
++    }
++#endif
++#undef CREATE_FA
++
++    const int mul_mat_id_param_count = 5;
++
++#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++    if (device->coopmat2) {
++
++        // Create 6 variants, {s,m,l}x{unaligned,aligned}
++#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true);   \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true);   \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true);   \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true);   \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true);   \
++        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true);   \
++
++        // Create 2 variants, {f16,f32} accumulator
++#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
++        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
++        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
++
++        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        if (device->coopmat_bf16_support) {
++            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
++        }
++#endif
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S],   matmul_iq1_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M],   matmul_iq1_m_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S],   matmul_iq2_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S],   matmul_iq3_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4],   matmul_mxfp4_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
++
++        GGML_ASSERT(device->subgroup_ballot);
++
++        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        if (device->coopmat_bf16_support) {
++            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
++        }
++#endif
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
++#undef CREATE_MM
++#undef CREATE_MM2
++    } else
++#endif  // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++    if (device->coopmat_support) {
++        // Create 6 variants, {s,m,l}x{unaligned,aligned}
++#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true);   \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true);   \
++
++        // Create 2 variants, {f16,f32} accumulator
++#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
++        if (device->coopmat_acc_f16_support) { \
++            CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
++        } \
++        if (device->coopmat_acc_f32_support) { \
++            CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
++        } \
++
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        if (device->coopmat_bf16_support) {
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
++        }
++#endif
++
++        if (device->coopmat_acc_f16_support) {
++            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++
++            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++        } else {
++            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++
++            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
++        }
++
++        GGML_ASSERT(device->subgroup_ballot);
++
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        if (device->coopmat_bf16_support) {
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        }
++#endif
++
++        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
++#undef CREATE_MM2
++#undef CREATE_MM
++    } else
++#endif  // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++    if (device->fp16) {
++        // Create 6 variants, {s,m,l}x{unaligned,aligned}
++#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++
++#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++        if (device->mul_mat ## ID ## _l[TYPE]) { \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC        "_l", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        } \
++        if (device->mul_mat ## ID ## _m[TYPE]) { \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC        "_m", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        } \
++        if (device->mul_mat ## ID ## _s[TYPE]) { \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC        "_s", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        } \
++
++        // Create 2 variants, {f16,f32} accumulator
++#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++        CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++        CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++        if (device->integer_dot_product) {
++            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++
++            CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
++
++            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
++            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
++        }
++#endif
++
++        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
++            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++
++            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++            if (device->integer_dot_product) {
++                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++
++                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++
++                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            }
++#endif
++        } else {
++            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++
++            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++            if (device->integer_dot_product) {
++                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++
++                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++
++                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            }
++#endif
++        }
++#undef CREATE_MM2
++#undef CREATE_MMQ
++#undef CREATE_MM
++    } else {
++        // Create 6 variants, {s,m,l}x{unaligned,aligned}
++#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
++
++#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
++        if (device->mul_mat ## ID ## _l[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
++        if (device->mul_mat ## ID ## _m[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
++        if (device->mul_mat ## ID ## _s[TYPE]) \
++            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
++
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++
++        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++        if (device->integer_dot_product) {
++            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
++
++            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
++            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
++        }
++#endif
++
++        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
++            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
++
++            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_subgroup_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_subgroup_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_subgroup_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_subgroup_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_subgroup_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_subgroup_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_subgroup_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_subgroup_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
++        } else {
++            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++
++            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++        }
++    }
++    // reusing CREATE_MM from the fp32 path
++    if ((device->coopmat2 || device->coopmat_support)
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        && !device->coopmat_bf16_support
++#endif
++        ) {
++        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
++
++        // use scalar tile sizes
++        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
++        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
++        s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
++
++        l_wg_denoms = {128, 128, 1 };
++        m_wg_denoms = { 64,  64, 1 };
++        s_wg_denoms = { 32,  32, 1 };
++
++        if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
++            // Xe2/Xe3 - bf16 warptile performance tuning
++            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
++        }
++
++        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
++        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
++    }
++#undef CREATE_MM
++
++    // mul mat vec
++
++    // the number of rows computed per shader depends on GPU model and quant
++    uint32_t rm_stdq = 1;
++    uint32_t rm_kq = 2;
++    uint32_t rm_stdq_int = 1;
++    uint32_t rm_kq_int = 1;
++    auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
++    if (device->vendor_id == VK_VENDOR_ID_AMD) {
++        if (device->architecture == AMD_GCN) {
++            rm_stdq = 2;
++            rm_kq = 4;
++            rm_stdq_int = 4;
++        }
++    } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
++        rm_stdq = 2;
++        rm_stdq_int = 2;
++    }
++    uint32_t rm_iq = 2 * rm_kq;
++
++    const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
++    // Ensure a subgroup size >= 16 is available
++    const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
++
++    const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
++    const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
++
++    const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
++    const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
++    static constexpr uint32_t mul_mat_vec_num_bindings = 5;
++    static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
++
++    for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
++        const uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
++        const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
++
++        const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
++                                            (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
++                                            SHADER_REDUCTION_MODE_SHMEM;
++
++        const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
++                                              (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
++                                              SHADER_REDUCTION_MODE_SHMEM;
++
++        for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32",  arr_dmmv_f32_f32_f32_len[reduc],  arr_dmmv_f32_f32_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32",  arr_dmmv_f16_f32_f32_len[reduc],  arr_dmmv_f16_f32_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i],   "mul_mat_vec_iq1_s_f32_f32",   arr_dmmv_iq1_s_f32_f32_len[reduc16],   arr_dmmv_iq1_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i],   "mul_mat_vec_iq1_m_f32_f32",   arr_dmmv_iq1_m_f32_f32_len[reduc16],   arr_dmmv_iq1_m_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i],  "mul_mat_vec_iq2_xs_f32_f32",  arr_dmmv_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_iq2_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f32_f32",   arr_dmmv_iq2_s_f32_f32_len[reduc16],   arr_dmmv_iq2_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f32_f32",   arr_dmmv_iq3_s_f32_f32_len[reduc16],   arr_dmmv_iq3_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f32_f32",  arr_dmmv_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_iq4_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f32_f32",  arr_dmmv_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_iq4_nl_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f32_f32",   arr_dmmv_mxfp4_f32_f32_len[reduc16],   arr_dmmv_mxfp4_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32",  arr_dmmv_f32_f16_f32_len[reduc],  arr_dmmv_f32_f16_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32",  arr_dmmv_f16_f16_f32_len[reduc],  arr_dmmv_f16_f16_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i],   "mul_mat_vec_iq1_s_f16_f32",   arr_dmmv_iq1_s_f16_f32_len[reduc16],   arr_dmmv_iq1_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i],   "mul_mat_vec_iq1_m_f16_f32",   arr_dmmv_iq1_m_f16_f32_len[reduc16],   arr_dmmv_iq1_m_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i],  "mul_mat_vec_iq2_xs_f16_f32",  arr_dmmv_iq2_xs_f16_f32_len[reduc16],  arr_dmmv_iq2_xs_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f16_f32",   arr_dmmv_iq2_s_f16_f32_len[reduc16],   arr_dmmv_iq2_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f16_f32",   arr_dmmv_iq3_s_f16_f32_len[reduc16],   arr_dmmv_iq3_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f16_f32",  arr_dmmv_iq4_xs_f16_f32_len[reduc16],  arr_dmmv_iq4_xs_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f16_f32",  arr_dmmv_iq4_nl_f16_f32_len[reduc16],  arr_dmmv_iq4_nl_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f16_f32",   arr_dmmv_mxfp4_f16_f32_len[reduc16],   arr_dmmv_mxfp4_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++            if (device->integer_dot_product) {
++                const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
++                const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
++
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
++
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
++                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
++
++            }
++#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
++        }
++
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",        arr_dmmv_id_f32_f32_f32_len[reduc],     arr_dmmv_id_f32_f32_f32_data[reduc],     "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",        arr_dmmv_id_f16_f32_f32_len[reduc],     arr_dmmv_id_f16_f32_f32_data[reduc],     "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32",       arr_dmmv_id_bf16_f32_f32_len[reduc],    arr_dmmv_id_bf16_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32",       arr_dmmv_id_q4_0_f32_f32_len[reduc],    arr_dmmv_id_q4_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32",       arr_dmmv_id_q4_1_f32_f32_len[reduc],    arr_dmmv_id_q4_1_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32",       arr_dmmv_id_q5_0_f32_f32_len[reduc],    arr_dmmv_id_q5_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32",       arr_dmmv_id_q5_1_f32_f32_len[reduc],    arr_dmmv_id_q5_1_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32",       arr_dmmv_id_q8_0_f32_f32_len[reduc],    arr_dmmv_id_q8_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32",       arr_dmmv_id_q2_k_f32_f32_len[reduc16],    arr_dmmv_id_q2_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32",       arr_dmmv_id_q3_k_f32_f32_len[reduc16],    arr_dmmv_id_q3_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32",       arr_dmmv_id_q4_k_f32_f32_len[reduc16],    arr_dmmv_id_q4_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32",       arr_dmmv_id_q5_k_f32_f32_len[reduc16],    arr_dmmv_id_q5_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32",       arr_dmmv_id_q6_k_f32_f32_len[reduc16],    arr_dmmv_id_q6_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S],   "mul_mat_vec_id_iq1_s_f32",   arr_dmmv_id_iq1_s_f32_f32_len[reduc16],   arr_dmmv_id_iq1_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M],   "mul_mat_vec_id_iq1_m_f32",   arr_dmmv_id_iq1_m_f32_f32_len[reduc16],   arr_dmmv_id_iq1_m_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS],  "mul_mat_vec_id_iq2_xs_f32",  arr_dmmv_id_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq2_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S],   "mul_mat_vec_id_iq2_s_f32",   arr_dmmv_id_iq2_s_f32_f32_len[reduc16],   arr_dmmv_id_iq2_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S],   "mul_mat_vec_id_iq3_s_f32",   arr_dmmv_id_iq3_s_f32_f32_len[reduc16],   arr_dmmv_id_iq3_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS],  "mul_mat_vec_id_iq4_xs_f32",  arr_dmmv_id_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq4_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL],  "mul_mat_vec_id_iq4_nl_f32",  arr_dmmv_id_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_id_iq4_nl_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4],   "mul_mat_vec_id_mxfp4_f32",   arr_dmmv_id_mxfp4_f32_f32_len[reduc16],   arr_dmmv_id_mxfp4_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
++
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++        if (device->integer_dot_product) {
++            const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
++            const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
++
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
++
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
++
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
++            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
++        }
++#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
++    }
++
++#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++    GGML_UNUSED(rm_stdq_int);
++    GGML_UNUSED(rm_kq_int);
++    GGML_UNUSED(rm_iq_int);
++#endif
++
++    // dequant shaders
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S],   "dequant_iq1_s",   dequant_iq1_s_len,   dequant_iq1_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M],   "dequant_iq1_m",   dequant_iq1_m_len,   dequant_iq1_m_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS],  "dequant_iq2_xs",  dequant_iq2_xs_len,  dequant_iq2_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S],   "dequant_iq2_s",   dequant_iq2_s_len,   dequant_iq2_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S],   "dequant_iq3_s",   dequant_iq3_s_len,   dequant_iq3_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS],  "dequant_iq4_xs",  dequant_iq4_xs_len,  dequant_iq4_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL],  "dequant_iq4_nl",  dequant_iq4_nl_len,  dequant_iq4_nl_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4],   "dequant_mxfp4",   dequant_mxfp4_len,   dequant_mxfp4_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
++
++    // get_rows
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S],   "get_rows_iq1_s",   get_rows_iq1_s_len,   get_rows_iq1_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M],   "get_rows_iq1_m",   get_rows_iq1_m_len,   get_rows_iq1_m_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS],  "get_rows_iq2_xs",  get_rows_iq2_xs_len,  get_rows_iq2_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S],   "get_rows_iq2_s",   get_rows_iq2_s_len,   get_rows_iq2_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S],   "get_rows_iq3_s",   get_rows_iq3_s_len,   get_rows_iq3_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs",  get_rows_iq4_xs_len,  get_rows_iq4_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl",  get_rows_iq4_nl_len,  get_rows_iq4_nl_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4],   "get_rows_mxfp4",   get_rows_mxfp4_len,   get_rows_mxfp4_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32],     "get_rows_i32",     get_rows_i32_len,     get_rows_i32_data,     "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S],   "get_rows_iq1_s_f32",   get_rows_iq1_s_f32_len,   get_rows_iq1_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M],   "get_rows_iq1_m_f32",   get_rows_iq1_m_f32_len,   get_rows_iq1_m_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS],  "get_rows_iq2_xs_f32",  get_rows_iq2_xs_f32_len,  get_rows_iq2_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S],   "get_rows_iq2_s_f32",   get_rows_iq2_s_f32_len,   get_rows_iq2_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S],   "get_rows_iq3_s_f32",   get_rows_iq3_s_f32_len,   get_rows_iq3_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs_f32",  get_rows_iq4_xs_f32_len,  get_rows_iq4_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
++
++    for (auto &it : device->pipeline_fa_mask_opt) {
++        auto BrBc = it.first;
++        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
++    }
++
++    if (device->subgroup_clustered && device->subgroup_require_full_support) {
++        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
++    } else {
++        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
++    }
++
++    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
++        if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
++            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
++        } else {
++            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
++        }
++    }
++    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
++
++    if (device->float_controls_rte_fp16 &&
++        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
++        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
++        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
++
++    if (device->float_controls_rte_fp16) {
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++    } else {
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
++    }
++
++#define SET_ROWS(itype, rte) \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32],  "set_rows_f32" #itype,  set_rows_f32 ## itype ## rte ## _len,  set_rows_f32 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16],  "set_rows_f16" #itype,  set_rows_f16 ## itype ## rte ## _len,  set_rows_f16 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
++        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
++
++    if (device->float_controls_rte_fp16) {
++        SET_ROWS(_i32, _rte)
++        SET_ROWS(_i64, _rte)
++    } else {
++        SET_ROWS(_i32, )
++        SET_ROWS(_i64, )
++    }
++#undef SET_ROWS
++
++
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
++
++    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
++        std::string s;
++        s += std::string(src0_f16 ? "_f16" : "_f32");
++        s += std::string(src1_f16 ? "_f16" : "_f32");
++        s += std::string(dst_f16 ? "_f16" : "_f32");
++        return s;
++    };
++
++    bool rte = device->float_controls_rte_fp16;
++#define CREATE_BINARY(name, namemod, spec, bindings) \
++    for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
++        ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
++                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
++                                "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
++
++    CREATE_BINARY(add, , {0}, 4)
++    CREATE_BINARY(add, _norepeat, {1}, 4)
++    CREATE_BINARY(sub, , {0}, 3)
++    CREATE_BINARY(sub, _norepeat, {1}, 3)
++    CREATE_BINARY(mul, , {0}, 3)
++    CREATE_BINARY(mul, _norepeat, {1}, 3)
++    CREATE_BINARY(div, , {0}, 3)
++    CREATE_BINARY(div, _norepeat, {1}, 3)
++    CREATE_BINARY(add_rms, , {0}, 4)
++    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
++#undef CREATE_BINARY
++
++    if (device->multi_add) {
++        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
++            ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i],     "multi_add_f32_"     + std::to_string(i+1), multi_add_f32_len,     multi_add_f32_data,     "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
++            ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
++        }
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    if (device->float_controls_rte_fp16) {
++        ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    } else {
++        ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
++
++#define CREATE_UNARY(name)  \
++    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);  \
++    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++    CREATE_UNARY(gelu)
++    CREATE_UNARY(gelu_erf)
++    CREATE_UNARY(gelu_quick)
++    CREATE_UNARY(silu)
++    CREATE_UNARY(relu)
++    CREATE_UNARY(xielu)
++    CREATE_UNARY(neg)
++    CREATE_UNARY(tanh)
++    CREATE_UNARY(sigmoid)
++    CREATE_UNARY(hardsigmoid)
++    CREATE_UNARY(hardswish)
++    CREATE_UNARY(abs)
++    CREATE_UNARY(softplus)
++    CREATE_UNARY(step)
++    CREATE_UNARY(round)
++    CREATE_UNARY(ceil)
++    CREATE_UNARY(floor)
++    CREATE_UNARY(trunc)
++#undef CREATE_UNARY
++
++#define CREATE_UNARY_RTE(name)  \
++    if (device->float_controls_rte_fp16) {  \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
++    } else {    \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
++    }
++    CREATE_UNARY_RTE(exp)
++#undef CREATE_UNARY_RTE
++
++    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++#define CREATE_GLU(name)  \
++    if (device->float_controls_rte_fp16) {  \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
++    } else {    \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
++        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
++    }
++
++    CREATE_GLU(geglu)
++    CREATE_GLU(reglu)
++    CREATE_GLU(swiglu)
++    CREATE_GLU(swiglu_oai)
++    CREATE_GLU(geglu_erf)
++    CREATE_GLU(geglu_quick)
++#undef CREATE_GLU
++
++    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
++
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
++
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32,     "soft_max_large1_f32",     soft_max_large1_f32_len,     soft_max_large1_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32,     "soft_max_large2_f32",     soft_max_large2_f32_len,     soft_max_large2_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32,     "soft_max_large3_f32",     soft_max_large3_f32_len,     soft_max_large3_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
++
++    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++
++    if (device->float_controls_rte_fp16) {
++        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++
++        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++    } else {
++        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++
++        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
++    }
++
++    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
++        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
++        if (i <= device->max_workgroup_size_log2 &&
++            2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
++            const uint32_t NCOLS_PADDED_LOG2 = i;
++            ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
++        }
++        const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
++        BLOCK_SIZE /= WG_UNROLL_FACTOR;
++        ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
++    }
++
++    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
++        const uint32_t BLOCK_SIZE = 1u << i;
++        const uint32_t NCOLS_PADDED_LOG2 = i;
++        if (i <= device->max_workgroup_size_log2) {
++            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
++                                  sizeof(int) * device->subgroup_size +
++                                  2 * sizeof(int) +
++                                  2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
++            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
++                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
++                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
++            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
++                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
++            }
++        }
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
++
++    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
++    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32,       "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
++    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
++    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
++    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
++
++    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
++
++    for (auto &s : device->pipeline_solve_tri_f32) {
++        const vk_solve_tri_pipeline_state &state = s.first;
++
++        // Max number of rows to load at a time, limited by shared memory
++        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
++        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
++        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
++
++        ggml_vk_create_pipeline(
++            device, s.second, "solve_tri_f32",
++            solve_tri_f32_len, solve_tri_f32_data, "main", 3,
++            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
++    }
++
++#define IM2COL(bda) \
++    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
++    ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
++    if (device->float_controls_rte_fp16) {  \
++        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
++        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
++    } else {    \
++        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
++        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
++    }
++    if (device->shader_int64 && device->buffer_device_address) {
++        IM2COL(_bda)
++    } else {
++        IM2COL()
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
++
++    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
++        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
++        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
++    } else {
++        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
++        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
++
++    // conv2d, conv_transpose_2d
++    for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
++        uint32_t conv2d_WG_SIZE  = 256;
++        uint32_t use_collectives = 0;  // Enables subgroup ops for preventing the re-calculation of indices.
++        uint32_t conv2d_TS_K     = (s == CONV_SHAPE_64x32) ? 4 : 8;
++        uint32_t conv2d_SHMEM_PAD = 4;
++        vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s];
++        bool conv2d_UNROLL = true;
++
++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++        if (device->coopmat2) {
++            conv2d_SHMEM_PAD = 8; // 8 float16_t
++        }
++#endif
++
++        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
++            conv2d_SHMEM_PAD = 0;
++            conv2d_UNROLL = false;
++        } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
++            conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
++            if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) {
++                conv2d_UNROLL = false;
++            }
++        }
++
++        // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
++        bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
++                                    device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
++        bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
++                                     device->architecture == vk_device_architecture::AMD_GCN;
++
++        if (device->subgroup_shuffle &&
++            device->vendor_id != VK_VENDOR_ID_INTEL &&   // Do not enable collectives on Intel, see PR 14316.
++            allow_collectives_nv &&
++            allow_collectives_amd) {
++            use_collectives = 1;
++            conv2d_BS.CRS   = std::min(
++                device->subgroup_size,
++                conv2d_BS.CRS);  // CRS block size should be capped at subgroup size for correctness when shuffle is used.
++        }
++
++        uint32_t conv2d_shmem_req =
++            (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
++        if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
++            conv2d_BS.CRS = 8;
++            if (use_collectives) {
++                conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
++            }
++        }
++
++        std::array wg_denoms = { conv2d_BS.K, 1, 1 };
++        std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
++
++#define CREATE_CONV(name, type_suffix, spv_suffix) \
++        for (auto &c : device->pipeline_##name##type_suffix[s]) { \
++            const vk_conv2d_pipeline_state &state = c.first;  \
++            std::vector spec_constants_cpy = spec_constants; \
++            spec_constants_cpy.push_back(state.s0); \
++            spec_constants_cpy.push_back(state.s1); \
++            spec_constants_cpy.push_back(state.p0); \
++            spec_constants_cpy.push_back(state.p1); \
++            spec_constants_cpy.push_back(state.d0); \
++            spec_constants_cpy.push_back(state.d1); \
++            spec_constants_cpy.push_back(state.KW); \
++            spec_constants_cpy.push_back(state.KH); \
++            ggml_vk_create_pipeline( \
++                device, c.second, #name #type_suffix, \
++                name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
++                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives);    \
++        }
++#define CREATE_CONVS(spv_suffix) \
++        CREATE_CONV(conv2d, _f32, spv_suffix) \
++        CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
++        CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
++        CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix)
++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++        if (device->coopmat2) {
++            CREATE_CONVS(_cm2)
++        } else
++#endif
++        if (conv2d_UNROLL) {
++            CREATE_CONVS(_unroll)
++        } else {
++            CREATE_CONVS( )
++        }
++#undef CREATE_CONV
++#undef CREATE_CONVS
++    }
++
++    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
++    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
++
++    for (uint32_t use_push = 0; use_push < 2; ++use_push) {
++        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
++            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size);
++        }
++    }
++
++    for (auto &c : compiles) {
++        c.wait();
++    }
++}
++
++static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
++static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
++
++static vk_device ggml_vk_get_device(size_t idx) {
++    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
++
++    if (vk_instance.devices[idx] == nullptr) {
++        VK_LOG_DEBUG("Initializing new vk_device");
++        vk_device device = std::make_shared();
++        vk_instance.devices[idx] = device;
++
++        device->memory_logger = std::unique_ptr(new vk_memory_logger());
++
++        size_t dev_num = vk_instance.device_indices[idx];
++
++        std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices();
++
++        if (dev_num >= physical_devices.size()) {
++            std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
++            throw std::runtime_error("Device not found");
++        }
++
++        device->physical_device = physical_devices[dev_num];
++        const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties();
++
++        device->architecture = get_device_architecture(device->physical_device);
++
++        const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
++        device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
++
++        const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
++        device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
++
++        const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
++        device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
++
++        const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
++        device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;
++
++        bool fp16_storage = false;
++        bool fp16_compute = false;
++        bool maintenance4_support = false;
++        bool sm_builtins = false;
++        bool amd_shader_core_properties2 = false;
++        bool pipeline_robustness = false;
++        bool coopmat2_support = false;
++        bool pipeline_executable_properties_support = false;
++        device->coopmat_support = false;
++        device->integer_dot_product = false;
++        device->shader_64b_indexing = false;
++        bool bfloat16_support = false;
++
++        for (const auto& properties : ext_props) {
++            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
++                maintenance4_support = true;
++            } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
++                fp16_storage = true;
++            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
++                fp16_compute = true;
++            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
++                sm_builtins = true;
++            } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
++                amd_shader_core_properties2 = true;
++            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
++                pipeline_robustness = true;
++            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
++                device->subgroup_size_control = true;
++#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++            } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
++                       !getenv("GGML_VK_DISABLE_COOPMAT")) {
++                device->coopmat_support = true;
++                device->coopmat_m = 0;
++                device->coopmat_n = 0;
++                device->coopmat_k = 0;
++#endif
++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++            } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
++                       !getenv("GGML_VK_DISABLE_COOPMAT2")) {
++                coopmat2_support = true;
++#endif
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
++                       !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
++                device->integer_dot_product = true;
++#endif
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++            } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
++                       !getenv("GGML_VK_DISABLE_BFLOAT16")) {
++                bfloat16_support = true;
++#endif
++            } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
++                pipeline_executable_properties_support = true;
++            } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
++                       getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
++                device->memory_priority = true;
++            } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
++                device->external_memory_host = true;
++#if defined(VK_EXT_shader_64bit_indexing)
++            } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
++                device->shader_64b_indexing = true;
++#endif
++            }
++        }
++
++        vk::PhysicalDeviceProperties2 props2;
++        vk::PhysicalDeviceMaintenance3Properties props3;
++        vk::PhysicalDeviceMaintenance4Properties props4;
++        vk::PhysicalDeviceSubgroupProperties subgroup_props;
++        vk::PhysicalDeviceDriverProperties driver_props;
++        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
++        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
++        vk::PhysicalDeviceVulkan11Properties vk11_props;
++        vk::PhysicalDeviceVulkan12Properties vk12_props;
++        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
++        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
++        vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
++
++        props2.pNext = &props3;
++        props3.pNext = &subgroup_props;
++        subgroup_props.pNext = &driver_props;
++        driver_props.pNext = &vk11_props;
++        vk11_props.pNext = &vk12_props;
++
++        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
++
++        if (maintenance4_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&props4;
++            last_struct = (VkBaseOutStructure *)&props4;
++        }
++        if (sm_builtins) {
++            last_struct->pNext = (VkBaseOutStructure *)&sm_props;
++            last_struct = (VkBaseOutStructure *)&sm_props;
++        }
++        if (amd_shader_core_properties2) {
++            last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
++            last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
++        }
++        if (device->subgroup_size_control) {
++            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
++            last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
++        }
++
++#if defined(VK_NV_cooperative_matrix2)
++        vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
++        if (coopmat2_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
++            last_struct = (VkBaseOutStructure *)&coopmat2_props;
++        }
++#endif
++
++        if (device->integer_dot_product) {
++            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
++            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
++        }
++
++        if (device->external_memory_host) {
++            last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
++            last_struct = (VkBaseOutStructure *)&external_memory_host_props;
++        }
++
++        device->physical_device.getProperties2(&props2);
++        device->properties = props2.properties;
++        device->vendor_id = device->properties.vendorID;
++        device->driver_id = driver_props.driverID;
++
++        if (device->driver_id == vk::DriverId::eMoltenvk) {
++            // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
++            // is available in the Vulkan SDK.
++            device->external_memory_host = false;
++        }
++
++        // Implementing the async backend interfaces seems broken on older Intel HW,
++        // see https://github.com/ggml-org/llama.cpp/issues/17302.
++        device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
++                                 std::string(device->properties.deviceName.data()).find("(DG1)") == std::string::npos) &&
++                                getenv("GGML_VK_DISABLE_ASYNC") == nullptr;
++
++        if (!device->support_async) {
++            GGML_LOG_DEBUG("ggml_vulkan: WARNING: Async execution disabled on certain Intel devices.\n");
++        }
++
++        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
++
++        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
++            device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
++        } else if (maintenance4_support) {
++            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
++        } else {
++            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
++        }
++
++        const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
++
++        if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
++            device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
++        } else if (maintenance4_support) {
++            device->max_buffer_size = props4.maxBufferSize;
++        } else {
++            device->max_buffer_size = device->max_memory_allocation_size;
++        }
++
++        const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
++
++        if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
++            device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
++        } else {
++            // Limit batching of allocations to 1GB by default to avoid fragmentation issues
++            device->suballocation_block_size = 1024*1024*1024;
++        }
++        device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
++
++        device->subgroup_size = subgroup_props.subgroupSize;
++        device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
++        device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
++        if (sm_builtins) {
++            device->shader_core_count = sm_props.shaderSMCount;
++        } else if (amd_shader_core_properties2) {
++            device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
++        } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
++            device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
++        } else {
++            device->shader_core_count = 0;
++        }
++        device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
++
++        device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
++        device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                      (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
++#ifdef __APPLE__
++        // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)
++        if (device->vendor_id == VK_VENDOR_ID_AMD) {
++            device->subgroup_arithmetic = false;
++        }
++#endif
++        device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
++        device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                     (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
++
++        device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
++
++        device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
++                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
++
++        const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
++
++        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
++
++        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
++            device->coopmat_support = false;
++        }
++
++        device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
++
++        device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
++
++        device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
++
++        std::vector queue_family_props = device->physical_device.getQueueFamilyProperties();
++
++        // Try to find a non-graphics compute queue and transfer-focused queues
++        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
++        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
++
++        const float priorities[] = { 1.0f, 1.0f };
++        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
++
++        std::vector device_queue_create_infos;
++        if (compute_queue_family_index != transfer_queue_family_index) {
++            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
++            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
++        } else if(!device->single_queue) {
++            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
++        } else {
++            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
++        }
++        vk::DeviceCreateInfo device_create_info;
++        std::vector device_extensions;
++        vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
++
++        VkPhysicalDeviceFeatures2 device_features2;
++        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
++        device_features2.pNext = nullptr;
++        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
++
++        VkPhysicalDeviceVulkan11Features vk11_features;
++        vk11_features.pNext = nullptr;
++        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
++        device_features2.pNext = &vk11_features;
++
++        VkPhysicalDeviceVulkan12Features vk12_features;
++        vk12_features.pNext = nullptr;
++        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
++        vk11_features.pNext = &vk12_features;
++
++        last_struct = (VkBaseOutStructure *)&vk12_features;
++
++        VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
++        pl_robustness_features.pNext = nullptr;
++        pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
++        pl_robustness_features.pipelineRobustness = VK_FALSE;
++
++        if (pipeline_robustness) {
++            last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
++            last_struct = (VkBaseOutStructure *)&pl_robustness_features;
++            device_extensions.push_back("VK_EXT_pipeline_robustness");
++        }
++
++        VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;
++        memory_priority_features.pNext = nullptr;
++        memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;
++        memory_priority_features.memoryPriority = VK_FALSE;
++        if (device->memory_priority) {
++            last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;
++            last_struct = (VkBaseOutStructure *)&memory_priority_features;
++            device_extensions.push_back("VK_EXT_memory_priority");
++        }
++
++        VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
++        subgroup_size_control_features.pNext = nullptr;
++        subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
++        subgroup_size_control_features.computeFullSubgroups = false;
++        subgroup_size_control_features.subgroupSizeControl = false;
++
++        if (device->subgroup_size_control) {
++            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
++            last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
++        }
++
++#if defined(VK_KHR_cooperative_matrix)
++        VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
++        coopmat_features.pNext = nullptr;
++        coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
++        coopmat_features.cooperativeMatrix = VK_FALSE;
++
++        if (device->coopmat_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
++            last_struct = (VkBaseOutStructure *)&coopmat_features;
++        }
++#endif
++
++#if defined(VK_NV_cooperative_matrix2)
++        VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
++        coopmat2_features.pNext = nullptr;
++        coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
++        if (coopmat2_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
++            last_struct = (VkBaseOutStructure *)&coopmat2_features;
++            device_extensions.push_back("VK_NV_cooperative_matrix2");
++        }
++#endif
++
++#if defined(VK_KHR_shader_bfloat16)
++        VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
++        bfloat16_features.pNext = nullptr;
++        bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
++        if (bfloat16_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
++            last_struct = (VkBaseOutStructure *)&bfloat16_features;
++            device_extensions.push_back("VK_KHR_shader_bfloat16");
++        }
++#endif
++
++        VkPhysicalDeviceMaintenance4Features maint4_features {};
++        maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
++        if (maintenance4_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
++            last_struct = (VkBaseOutStructure *)&maint4_features;
++            device_extensions.push_back("VK_KHR_maintenance4");
++        }
++
++        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
++        shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
++        if (device->integer_dot_product) {
++            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
++            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
++            device_extensions.push_back("VK_KHR_shader_integer_dot_product");
++        }
++
++        VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
++        pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
++        if (pipeline_executable_properties_support) {
++            last_struct->pNext = (VkBaseOutStructure *)&pep_features;
++            last_struct = (VkBaseOutStructure *)&pep_features;
++            device_extensions.push_back("VK_KHR_pipeline_executable_properties");
++        }
++
++        if (device->external_memory_host) {
++            device_extensions.push_back("VK_EXT_external_memory_host");
++        }
++
++#if defined(VK_EXT_shader_64bit_indexing)
++        VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
++        shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
++        if (device->shader_64b_indexing) {
++            last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
++            last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
++            device_extensions.push_back("VK_EXT_shader_64bit_indexing");
++        }
++#endif
++
++        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
++
++        device->pipeline_executable_properties_support = pipeline_executable_properties_support;
++
++        device->fp16 = device->fp16 && vk12_features.shaderFloat16;
++
++#if defined(VK_KHR_shader_bfloat16)
++        device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
++#else
++        device->bf16 = false;
++#endif
++
++        device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
++
++        device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
++                            device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
++                            getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
++
++        device->shader_int64 = device_features2.features.shaderInt64;
++        device->buffer_device_address = vk12_features.bufferDeviceAddress;
++        device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
++
++        if (device->subgroup_size_control) {
++            device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
++            device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
++            device_extensions.push_back("VK_EXT_subgroup_size_control");
++        }
++
++        device->subgroup_size_control = device->subgroup_size_control &&
++                (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
++                subgroup_size_control_features.subgroupSizeControl;
++
++        device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
++
++#if defined(VK_KHR_cooperative_matrix)
++        device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
++        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
++#endif
++
++        if (coopmat2_support) {
++#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++            if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
++                coopmat2_features.cooperativeMatrixFlexibleDimensions &&
++                coopmat2_features.cooperativeMatrixReductions &&
++                coopmat2_features.cooperativeMatrixConversions &&
++                coopmat2_features.cooperativeMatrixPerElementOperations &&
++                coopmat2_features.cooperativeMatrixTensorAddressing &&
++                coopmat2_features.cooperativeMatrixBlockLoads &&
++                vk12_features.bufferDeviceAddress) {
++
++                std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
++                uint32_t count = 0;
++
++                PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
++                    _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
++                        (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
++                        vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
++
++                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
++
++                VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
++                empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
++                flexible_dimensions.resize(count, empty_prop);
++
++                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
++
++                bool found_fp16_128 = false,
++                     found_fp16_256 = false,
++                     found_fp32_128 = false,
++                     found_fp32_256 = false;
++                // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
++                // with 32x16x16 and 256 with 32x32x16.
++                for (auto &prop : flexible_dimensions) {
++                    if (prop.saturatingAccumulation == VK_FALSE &&
++                        prop.scope == VK_SCOPE_WORKGROUP_KHR &&
++                        prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
++                        prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
++
++                        if (prop.workgroupInvocations == 128 &&
++                            prop.MGranularity <= 32 &&
++                            prop.NGranularity <= 16 &&
++                            prop.KGranularity <= 16) {
++                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
++                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
++                                found_fp16_128 = true;
++                            }
++                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
++                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
++                                found_fp32_128 = true;
++                            }
++                        }
++                        if (prop.workgroupInvocations == 256 &&
++                            prop.MGranularity <= 32 &&
++                            prop.NGranularity <= 32 &&
++                            prop.KGranularity <= 16) {
++                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
++                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
++                                found_fp16_256 = true;
++                            }
++                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
++                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
++                                found_fp32_256 = true;
++                            }
++                        }
++                    }
++                }
++                if (found_fp16_128 && found_fp16_256 &&
++                    found_fp32_128 && found_fp32_256 &&
++                    coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
++                    device->coopmat2 = true;
++                }
++            }
++#endif
++        }
++
++        if (!vk11_features.storageBuffer16BitAccess) {
++            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
++            throw std::runtime_error("Unsupported device");
++        }
++
++        device_extensions.push_back("VK_KHR_16bit_storage");
++
++#ifdef GGML_VULKAN_VALIDATE
++        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
++#endif
++
++        if (device->fp16) {
++            device_extensions.push_back("VK_KHR_shader_float16_int8");
++        }
++
++#if defined(VK_KHR_cooperative_matrix)
++        if (device->coopmat_support) {
++            // Query supported shapes
++            std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
++
++            PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
++                (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
++
++            uint32_t cm_props_num;
++
++            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
++
++            cm_props.resize(cm_props_num);
++
++            for (auto& prop : cm_props) {
++                prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
++            }
++
++            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
++
++            VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
++
++            for (auto& prop : cm_props) {
++                VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
++
++                if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
++                    (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
++                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
++                ) {
++                    if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
++                        (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
++                        // coopmat sizes not set yet
++                        if (device->coopmat_m == 0) {
++                            device->coopmat_acc_f32_support = true;
++                            device->coopmat_m = prop.MSize;
++                            device->coopmat_n = prop.NSize;
++                            device->coopmat_k = prop.KSize;
++                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
++                            // Only enable if shape is identical
++                            device->coopmat_acc_f32_support = true;
++                        }
++                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
++                            device->coopmat_support_16x16x16_f32acc = true;
++                        }
++                    } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
++                               (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
++                        // coopmat sizes not set yet
++                        if (device->coopmat_m == 0) {
++                            device->coopmat_acc_f16_support = true;
++                            device->coopmat_m = prop.MSize;
++                            device->coopmat_n = prop.NSize;
++                            device->coopmat_k = prop.KSize;
++                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
++                            // Only enable if shape is identical
++                            device->coopmat_acc_f16_support = true;
++                        }
++                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
++                            device->coopmat_support_16x16x16_f16acc = true;
++                        }
++                    }
++                } else if ((vk::ComponentTypeKHR)prop.AType      == vk::ComponentTypeKHR::eSint8 &&
++                           (vk::ComponentTypeKHR)prop.BType      == vk::ComponentTypeKHR::eSint8 &&
++                           (vk::ComponentTypeKHR)prop.CType      == vk::ComponentTypeKHR::eSint32 &&
++                           (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
++                           (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
++                           device->coopmat_int_m == 0
++                ) {
++                    device->coopmat_int_support = true;
++                    device->coopmat_int_m = prop.MSize;
++                    device->coopmat_int_n = prop.NSize;
++                    device->coopmat_int_k = prop.KSize;
++                }
++#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++                if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
++                    prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
++                    prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
++                    prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
++                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
++                ) {
++                    // coopmat sizes not set yet
++                    if (device->coopmat_m == 0) {
++                        device->coopmat_bf16_support = true;
++                        device->coopmat_m = prop.MSize;
++                        device->coopmat_n = prop.NSize;
++                        device->coopmat_k = prop.KSize;
++                    } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
++                        // Only enable if shape is identical
++                        device->coopmat_bf16_support = true;
++                    }
++                }
++#endif
++            }
++
++            if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
++                // No suitable matmul mode found
++                GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
++                device->coopmat_support = false;
++            }
++            if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
++                device->coopmat_bf16_support = false;
++            }
++        }
++
++        if (device->coopmat_support) {
++            device_extensions.push_back("VK_KHR_cooperative_matrix");
++        }
++#if defined(VK_KHR_shader_bfloat16)
++        if (device->coopmat_bf16_support) {
++            device_extensions.push_back("VK_KHR_shader_bfloat16");
++        }
++#endif
++#endif
++        device->name = GGML_VK_NAME + std::to_string(idx);
++
++        device_create_info = {
++            vk::DeviceCreateFlags(),
++            device_queue_create_infos,
++            {},
++            device_extensions
++        };
++        device_create_info.setPNext(&device_features2);
++        device->device = device->physical_device.createDevice(device_create_info);
++
++        // Queues
++        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
++
++        // Shaders
++        // Disable matmul tile sizes early if performance low or not supported
++        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
++            switch (device->vendor_id) {
++#ifndef GGML_VULKAN_RUN_TESTS
++            case VK_VENDOR_ID_AMD:
++                device->mul_mat_l[i]    = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary;
++                device->mul_mat_m[i]    = true;
++                device->mul_mat_s[i]    = true;
++                device->mul_mat_id_l[i] = false;
++                device->mul_mat_id_m[i] = true;
++                device->mul_mat_id_s[i] = true;
++                break;
++            case VK_VENDOR_ID_INTEL:
++                if (!device->coopmat_support || device->architecture != INTEL_XE2) {
++                    device->mul_mat_l[i] = false;
++                    device->mul_mat_id_l[i] = false;
++                } else {
++                    device->mul_mat_l[i] = true;  // if coopmat & XE2+, allow large matmul warptile config for Intel
++                    device->mul_mat_id_l[i] = true;
++                }
++                device->mul_mat_m[i] = true;
++                device->mul_mat_s[i] = true;
++                device->mul_mat_id_m[i] = true;
++                device->mul_mat_id_s[i] = true;
++                break;
++            case VK_VENDOR_ID_APPLE:
++                device->mul_mat_l[i] = false;
++                device->mul_mat_m[i] = true;
++                device->mul_mat_s[i] = false;
++                device->mul_mat_id_l[i] = false;
++                device->mul_mat_id_m[i] = true;
++                device->mul_mat_id_s[i] = false;
++                break;
++#endif
++            default:
++                device->mul_mat_l[i] = true;
++                device->mul_mat_m[i] = true;
++                device->mul_mat_s[i] = true;
++                device->mul_mat_id_l[i] = true;
++                device->mul_mat_id_m[i] = true;
++                device->mul_mat_id_s[i] = true;
++                break;
++            }
++        }
++
++
++        std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
++        std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
++        for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
++            dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
++            dsl_binding_flags.push_back({});
++        }
++
++        vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
++
++        vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
++            {},
++            dsl_binding);
++        descriptor_set_layout_create_info.setPNext(&dslbfci);
++        device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
++
++        ggml_vk_load_shaders(device);
++
++        if (!device->single_queue) {
++            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
++            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
++        } else {
++            // TODO: Use pointer or reference to avoid copy
++            device->transfer_queue.copyFrom(device->compute_queue);
++            device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
++        }
++
++        device->buffer_type = {
++            /* .iface    = */ ggml_backend_vk_buffer_type_interface,
++            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
++            /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
++        };
++
++        device->fence = device->device.createFence({});
++
++        device->idx = idx;
++
++        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
++
++        device->add_rms_fusion = !device->disable_fusion &&
++                                 device->subgroup_arithmetic &&
++                                 device->vendor_id != VK_VENDOR_ID_INTEL;
++        device->partials_binding_alignment =
++            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
++
++        device->mmvq_mode = 0;
++        if (getenv("GGML_VK_DISABLE_MMVQ")) {
++            device->mmvq_mode = -1;
++        } else if (getenv("GGML_VK_FORCE_MMVQ")) {
++            device->mmvq_mode = 1;
++        }
++
++        // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613
++        const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary;
++        device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn;
++
++        return device;
++    }
++
++    return vk_instance.devices[idx];
++}
++
++static void ggml_vk_print_gpu_info(size_t idx) {
++    GGML_ASSERT(idx < vk_instance.device_indices.size());
++    size_t dev_num = vk_instance.device_indices[idx];
++    VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
++    GGML_ASSERT(vk_instance_initialized);
++
++    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
++
++    if (dev_num >= devices.size()) {
++        std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
++        throw std::runtime_error("Device not found");
++    }
++
++    vk::PhysicalDevice physical_device = devices[dev_num];
++    std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
++
++    bool fp16_storage = false;
++    bool fp16_compute = false;
++    bool coopmat_support = false;
++    bool coopmat2_support = false;
++    bool integer_dot_product = false;
++    bool bfloat16_support = false;
++
++    for (auto properties : ext_props) {
++        if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
++            fp16_storage = true;
++        } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
++            fp16_compute = true;
++#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++       } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
++                   !getenv("GGML_VK_DISABLE_COOPMAT")) {
++            coopmat_support = true;
++#endif
++#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
++        } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
++                   !getenv("GGML_VK_DISABLE_COOPMAT2")) {
++            coopmat2_support = true;
++#endif
++#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
++        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
++                    !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
++            integer_dot_product = true;
++#endif
++#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
++        } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
++                    !getenv("GGML_VK_DISABLE_BFLOAT16")) {
++            bfloat16_support = true;
++#endif
++        }
++    }
++
++    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
++
++    const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
++    bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
++
++    bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
++
++    vk::PhysicalDeviceProperties2 props2;
++    vk::PhysicalDeviceMaintenance3Properties props3;
++    vk::PhysicalDeviceSubgroupProperties subgroup_props;
++    vk::PhysicalDeviceDriverProperties driver_props;
++    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
++    props2.pNext = &props3;
++    props3.pNext = &subgroup_props;
++    subgroup_props.pNext = &driver_props;
++
++    // Pointer to the last chain element
++    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
++
++    if (integer_dot_product) {
++        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
++        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
++    }
++
++    physical_device.getProperties2(&props2);
++
++    VkPhysicalDeviceFeatures2 device_features2;
++    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
++    device_features2.pNext = nullptr;
++
++    VkPhysicalDeviceVulkan11Features vk11_features;
++    vk11_features.pNext = nullptr;
++    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
++    device_features2.pNext = &vk11_features;
++
++    VkPhysicalDeviceVulkan12Features vk12_features;
++    vk12_features.pNext = nullptr;
++    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
++    vk11_features.pNext = &vk12_features;
++
++    // Pointer to the last chain element
++    last_struct = (VkBaseOutStructure *)&vk12_features;
++
++#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++    VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
++    coopmat_features.pNext = nullptr;
++    coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
++    coopmat_features.cooperativeMatrix = VK_FALSE;
++
++    if (coopmat_support) {
++        last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
++        last_struct = (VkBaseOutStructure *)&coopmat_features;
++    }
++#endif
++
++    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
++    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
++    if (integer_dot_product) {
++        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
++        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
++    }
++
++#if defined(VK_KHR_shader_bfloat16)
++    VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
++    bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
++    if (bfloat16_support) {
++        last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
++        last_struct = (VkBaseOutStructure *)&bfloat16_features;
++    }
++#endif
++
++    vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
++
++    fp16 = fp16 && vk12_features.shaderFloat16;
++
++#if defined(VK_KHR_shader_bfloat16)
++    bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
++#else
++    bool bf16 = false;
++#endif
++
++    uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
++    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
++    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
++
++    integer_dot_product = integer_dot_product
++                       && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
++                       && shader_integer_dot_product_features.shaderIntegerDotProduct;
++
++    coopmat_support = coopmat_support
++#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
++                   && coopmat_features.cooperativeMatrix
++#endif
++                   && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
++
++    std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
++
++    std::string device_name = props2.properties.deviceName.data();
++    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
++              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
++              props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
++
++    if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
++        GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
++    }
++}
++
++static bool ggml_vk_instance_layer_settings_available();
++static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
++static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
++static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
++
++static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
++DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
++    return ggml_vk_default_dispatcher_instance;
++}
++
++static void ggml_vk_instance_init() {
++    if (vk_instance_initialized) {
++        return;
++    }
++    VK_LOG_DEBUG("ggml_vk_instance_init()");
++
++    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
++    ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
++
++    uint32_t api_version = vk::enumerateInstanceVersion();
++
++    if (api_version < VK_API_VERSION_1_2) {
++        std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
++        throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required");
++    }
++
++    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
++
++    const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
++    const bool layer_settings = ggml_vk_instance_layer_settings_available();
++#ifdef __APPLE__
++    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
++#endif
++    const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
++    std::vector<const char*> layers;
++
++    if (layer_settings) {
++        layers.push_back("VK_LAYER_KHRONOS_validation");
++    }
++    std::vector<const char*> extensions;
++    if (layer_settings) {
++        extensions.push_back("VK_EXT_layer_settings");
++    }
++#ifdef __APPLE__
++    if (portability_enumeration_ext) {
++        extensions.push_back("VK_KHR_portability_enumeration");
++    }
++#endif
++    if (debug_utils_ext) {
++        extensions.push_back("VK_EXT_debug_utils");
++    }
++    VkBool32 enable_best_practice = layer_settings;
++    std::vector<vk::LayerSettingEXT> settings = {
++        {
++            "VK_LAYER_KHRONOS_validation",
++            "validate_best_practices",
++            vk::LayerSettingTypeEXT::eBool32,
++            1,
++            &enable_best_practice
++        },
++    };
++    vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
++    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
++#ifdef __APPLE__
++    if (portability_enumeration_ext) {
++        instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
++    }
++#endif
++
++    vk_instance.instance = vk::createInstance(instance_create_info);
++    vk_instance_initialized = true;
++
++    if (debug_utils_ext) {
++        vk_instance.debug_utils_support              = true;
++        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
++        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
++        vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
++        vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
++        vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT =   (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
++        vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
++    }
++
++    vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
++    vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
++    vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
++    vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
++    const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
++    if (GGML_VK_PIPELINE_STATS != nullptr) {
++        vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
++    }
++    const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
++
++    if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
++        vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);
++    }
++
++    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
++    VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
++
++    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
++
++    // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
++    char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
++    if (devices_env != nullptr) {
++        size_t num_available_devices = devices.size();
++
++        std::string devices(devices_env);
++        std::replace(devices.begin(), devices.end(), ',', ' ');
++
++        std::stringstream ss(devices);
++        size_t tmp;
++        while (ss >> tmp) {
++            if(tmp >= num_available_devices) {
++                std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
++                throw std::runtime_error("Invalid Vulkan device index");
++            }
++            vk_instance.device_indices.push_back(tmp);
++        }
++    } else {
++        // If no vulkan devices are found, return early
++        if (devices.empty()) {
++            GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
++            return;
++        }
++
++        // Default to using all dedicated GPUs
++        for (size_t i = 0; i < devices.size(); i++) {
++            vk::PhysicalDeviceProperties2 new_props;
++            vk::PhysicalDeviceDriverProperties new_driver;
++            vk::PhysicalDeviceIDProperties new_id;
++            new_props.pNext = &new_driver;
++            new_driver.pNext = &new_id;
++            devices[i].getProperties2(&new_props);
++
++            if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
++                // Check if there are two physical devices corresponding to the same GPU
++                // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
++                // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
++                // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
++                // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
++                // driver is MoltenVK
++                auto old_device = std::find_if(
++                    vk_instance.device_indices.begin(),
++                    vk_instance.device_indices.end(),
++                    [&devices, &new_id, &new_driver](const size_t k){
++                        vk::PhysicalDeviceProperties2 old_props;
++                        vk::PhysicalDeviceDriverProperties old_driver;
++                        vk::PhysicalDeviceIDProperties old_id;
++                        old_props.pNext = &old_driver;
++                        old_driver.pNext = &old_id;
++                        devices[k].getProperties2(&old_props);
++
++                        bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
++                        same_uuid = same_uuid || (
++                            old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
++                            std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
++                        );
++                        bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
++
++                        return same_uuid && !both_molten_vk;
++                    }
++                );
++                if (old_device == vk_instance.device_indices.end()) {
++                    vk_instance.device_indices.push_back(i);
++                } else {
++                    // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
++                    // This can cause error when splitting layers aross the devices, need to keep only 1
++                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
++
++                    vk::PhysicalDeviceProperties2 old_props;
++                    vk::PhysicalDeviceDriverProperties old_driver;
++                    old_props.pNext = &old_driver;
++                    devices[*old_device].getProperties2(&old_props);
++
++                    std::map driver_priorities {};
++                    int old_priority = std::numeric_limits::max();
++                    int new_priority = std::numeric_limits::max();
++
++                    // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id
++                    // Smaller number -> higher priority
++                    switch (old_props.properties.vendorID) {
++                        case VK_VENDOR_ID_AMD:
++                            driver_priorities[vk::DriverId::eMesaRadv] = 1;
++                            driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
++                            driver_priorities[vk::DriverId::eAmdProprietary] = 3;
++                            break;
++                        case VK_VENDOR_ID_INTEL:
++                            driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
++                            driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
++                            break;
++                        case VK_VENDOR_ID_NVIDIA:
++                            driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
++#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
++                            driver_priorities[vk::DriverId::eMesaNvk] = 2;
++#endif
++                            break;
++                        case VK_VENDOR_ID_QUALCOMM:
++                            driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
++                            driver_priorities[vk::DriverId::eMesaTurnip] = 2;
++                            break;
++                    }
++                    driver_priorities[vk::DriverId::eMesaDozen] = 100;
++
++                    if (driver_priorities.count(old_driver.driverID)) {
++                        old_priority = driver_priorities[old_driver.driverID];
++                    }
++                    if (driver_priorities.count(new_driver.driverID)) {
++                        new_priority = driver_priorities[new_driver.driverID];
++                    }
++
++                    if (new_priority < old_priority) {
++                        auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
++                        vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
++                        vk_instance.device_indices.push_back(i);
++
++                        VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
++                    }
++                    else {
++                        VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl);
++                    }
++                }
++            }
++        }
++
++        // If no GPUs found, fall back to the first non-CPU device.
++        // If only CPU devices are available, return without devices.
++        if (vk_instance.device_indices.empty()) {
++            for (size_t i = 0; i < devices.size(); i++) {
++                if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {
++                    vk_instance.device_indices.push_back(i);
++                    break;
++                }
++            }
++        }
++
++        if (vk_instance.device_indices.empty()) {
++            GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
++            return;
++        }
++    }
++    GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
++
++    for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
++        vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
++        std::vector extensionprops = vkdev.enumerateDeviceExtensionProperties();
++
++        bool membudget_supported = false;
++        for (const auto & ext : extensionprops) {
++            if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
++                membudget_supported = true;
++                break;
++            }
++        }
++
++        vk_instance.device_supports_membudget.push_back(membudget_supported);
++
++        ggml_vk_print_gpu_info(i);
++    }
++}
++
++static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
++    VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
++    ggml_vk_instance_init();
++    GGML_ASSERT(idx < vk_instance.device_indices.size());
++
++    ctx->name = GGML_VK_NAME + std::to_string(idx);
++
++    ctx->device = ggml_vk_get_device(idx);
++
++    ctx->semaphore_idx = 0;
++    ctx->event_idx = 0;
++
++    ctx->prealloc_size_x = 0;
++    ctx->prealloc_size_y = 0;
++    ctx->prealloc_size_split_k = 0;
++    // Fixed size of 1KB, for deterministic behavior
++    ctx->prealloc_size_add_rms_partials = 1024;
++
++    ctx->fence = ctx->device->device.createFence({});
++    ctx->almost_ready_fence = ctx->device->device.createFence({});
++
++    ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
++
++    if (vk_perf_logger_enabled) {
++        ctx->perf_logger = std::unique_ptr(new vk_perf_logger());
++    }
++
++#ifdef GGML_VULKAN_CHECK_RESULTS
++    const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
++    vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
++    const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
++    vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
++#endif
++}
++
++static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
++    VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
++    switch (type) {
++        case GGML_TYPE_F32:
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_Q2_K:
++        case GGML_TYPE_Q3_K:
++        case GGML_TYPE_Q4_K:
++        case GGML_TYPE_Q5_K:
++        case GGML_TYPE_Q6_K:
++        case GGML_TYPE_IQ1_S:
++        case GGML_TYPE_IQ1_M:
++        case GGML_TYPE_IQ2_XXS:
++        case GGML_TYPE_IQ2_XS:
++        case GGML_TYPE_IQ2_S:
++        case GGML_TYPE_IQ3_XXS:
++        case GGML_TYPE_IQ3_S:
++        case GGML_TYPE_IQ4_XS:
++        case GGML_TYPE_IQ4_NL:
++        case GGML_TYPE_MXFP4:
++            break;
++        default:
++            return nullptr;
++    }
++
++    return ctx->device->pipeline_dequant[type];
++}
++
++static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
++    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
++    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
++        return ctx->device->pipeline_matmul_f32;
++    }
++    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
++        return ctx->device->pipeline_matmul_f32_f16;
++    }
++    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
++        return ctx->device->pipeline_matmul_bf16;
++    }
++    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_matmul_f16_f32.f16acc;
++        }
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_matmul_f16.f16acc;
++        }
++    } else {
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_matmul_f16_f32.f32acc;
++        }
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_matmul_f16.f32acc;
++        }
++    }
++
++    // MMQ
++    if (src1_type == GGML_TYPE_Q8_1) {
++        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
++
++        if (pipelines->is_empty()) {
++            return nullptr;
++        }
++
++        return pipelines;
++    }
++
++    if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
++        return nullptr;
++    }
++
++    switch (src0_type) {
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_Q2_K:
++        case GGML_TYPE_Q3_K:
++        case GGML_TYPE_Q4_K:
++        case GGML_TYPE_Q5_K:
++        case GGML_TYPE_Q6_K:
++        case GGML_TYPE_IQ1_S:
++        case GGML_TYPE_IQ1_M:
++        case GGML_TYPE_IQ2_XXS:
++        case GGML_TYPE_IQ2_XS:
++        case GGML_TYPE_IQ2_S:
++        case GGML_TYPE_IQ3_XXS:
++        case GGML_TYPE_IQ3_S:
++        case GGML_TYPE_IQ4_XS:
++        case GGML_TYPE_IQ4_NL:
++        case GGML_TYPE_MXFP4:
++            break;
++        default:
++            return nullptr;
++    }
++
++    if (ctx->device->coopmat2) {
++        assert(src1_type == GGML_TYPE_F16);
++        return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
++    }
++    if (ctx->device->coopmat_support) {
++        return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
++    }
++    return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
++}
++
++static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
++    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
++    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
++    GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
++
++    if (b_type == GGML_TYPE_Q8_1) {
++        switch (a_type) {
++            case GGML_TYPE_Q4_0:
++            case GGML_TYPE_Q4_1:
++            case GGML_TYPE_Q5_0:
++            case GGML_TYPE_Q5_1:
++            case GGML_TYPE_Q8_0:
++            case GGML_TYPE_MXFP4:
++            case GGML_TYPE_Q2_K:
++            case GGML_TYPE_Q3_K:
++            case GGML_TYPE_Q4_K:
++            case GGML_TYPE_Q5_K:
++            case GGML_TYPE_Q6_K:
++            case GGML_TYPE_IQ1_S:
++            case GGML_TYPE_IQ1_M:
++                break;
++            default:
++                return nullptr;
++        }
++    }
++
++    switch (a_type) {
++        case GGML_TYPE_F32:
++        case GGML_TYPE_F16:
++        case GGML_TYPE_BF16:
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_Q2_K:
++        case GGML_TYPE_Q3_K:
++        case GGML_TYPE_Q4_K:
++        case GGML_TYPE_Q5_K:
++        case GGML_TYPE_Q6_K:
++        case GGML_TYPE_IQ1_S:
++        case GGML_TYPE_IQ1_M:
++        case GGML_TYPE_IQ2_XXS:
++        case GGML_TYPE_IQ2_XS:
++        case GGML_TYPE_IQ2_S:
++        case GGML_TYPE_IQ3_XXS:
++        case GGML_TYPE_IQ3_S:
++        case GGML_TYPE_IQ4_XS:
++        case GGML_TYPE_IQ4_NL:
++        case GGML_TYPE_MXFP4:
++            break;
++        default:
++            return nullptr;
++    }
++
++    // heuristic to choose workgroup size
++    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
++    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
++        // Prefer larger workgroups when M is small, to spread the work out more
++        // and keep more SMs busy.
++        // q6_k seems to prefer small workgroup size even for "medium" values of M.
++        if (a_type == GGML_TYPE_Q6_K) {
++            if (m < 4096 && k >= 1024) {
++                dmmv_wg = DMMV_WG_SIZE_LARGE;
++            }
++        } else {
++            if (m <= 8192 && k >= 1024) {
++                dmmv_wg = DMMV_WG_SIZE_LARGE;
++            }
++        }
++    }
++
++    if (b_type == GGML_TYPE_Q8_1) {
++        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
++            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
++        }
++        return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
++    }
++
++    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
++}
++
++static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
++    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
++    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
++        return ctx->device->pipeline_matmul_id_f32;
++    }
++    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
++        return ctx->device->pipeline_matmul_id_bf16;
++    }
++    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
++        }
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_matmul_id_f16.f16acc;
++        }
++    } else {
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
++        }
++        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_matmul_id_f16.f32acc;
++        }
++    }
++
++    // MMQ
++    if (src1_type == GGML_TYPE_Q8_1) {
++        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
++
++        if (pipelines->is_empty()) {
++            return nullptr;
++        }
++
++        return pipelines;
++    }
++
++    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
++
++    switch (src0_type) {
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_Q2_K:
++        case GGML_TYPE_Q3_K:
++        case GGML_TYPE_Q4_K:
++        case GGML_TYPE_Q5_K:
++        case GGML_TYPE_Q6_K:
++        case GGML_TYPE_IQ1_S:
++        case GGML_TYPE_IQ1_M:
++        case GGML_TYPE_IQ2_XXS:
++        case GGML_TYPE_IQ2_XS:
++        case GGML_TYPE_IQ2_S:
++        case GGML_TYPE_IQ3_XXS:
++        case GGML_TYPE_IQ3_S:
++        case GGML_TYPE_IQ4_XS:
++        case GGML_TYPE_IQ4_NL:
++        case GGML_TYPE_MXFP4:
++            break;
++        default:
++            return nullptr;
++    }
++
++    vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
++    // XXX TODO 'prec' is not actually allowed in mul_mat_id.
++    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
++    bool support_fp16acc = !mmp.f16acc->is_empty();
++    bool support_fp32acc = !mmp.f32acc->is_empty();
++
++    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
++        return mmp.f16acc;
++    } else {
++        GGML_ASSERT(support_fp32acc);
++        return mmp.f32acc;
++    }
++}
++
++static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
++    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
++    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
++
++    if (b_type == GGML_TYPE_Q8_1) {
++        switch (a_type) {
++            case GGML_TYPE_Q4_0:
++            case GGML_TYPE_Q4_1:
++            case GGML_TYPE_Q5_0:
++            case GGML_TYPE_Q5_1:
++            case GGML_TYPE_Q8_0:
++            case GGML_TYPE_MXFP4:
++            case GGML_TYPE_Q2_K:
++            case GGML_TYPE_Q3_K:
++            case GGML_TYPE_Q4_K:
++            case GGML_TYPE_Q5_K:
++            case GGML_TYPE_Q6_K:
++            case GGML_TYPE_IQ1_S:
++            case GGML_TYPE_IQ1_M:
++                break;
++            default:
++                return nullptr;
++        }
++    }
++
++    switch (a_type) {
++        case GGML_TYPE_F32:
++        case GGML_TYPE_F16:
++        case GGML_TYPE_BF16:
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_Q2_K:
++        case GGML_TYPE_Q3_K:
++        case GGML_TYPE_Q4_K:
++        case GGML_TYPE_Q5_K:
++        case GGML_TYPE_Q6_K:
++        case GGML_TYPE_IQ1_S:
++        case GGML_TYPE_IQ1_M:
++        case GGML_TYPE_IQ2_XXS:
++        case GGML_TYPE_IQ2_XS:
++        case GGML_TYPE_IQ2_S:
++        case GGML_TYPE_IQ3_XXS:
++        case GGML_TYPE_IQ3_S:
++        case GGML_TYPE_IQ4_XS:
++        case GGML_TYPE_IQ4_NL:
++        case GGML_TYPE_MXFP4:
++            break;
++        default:
++            return nullptr;
++    }
++
++    // heuristic to choose workgroup size
++    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
++    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
++        // Prefer larger workgroups when M is small, to spread the work out more
++        // and keep more SMs busy.
++        // q6_k seems to prefer small workgroup size even for "medium" values of M.
++        if (a_type == GGML_TYPE_Q6_K) {
++            if (m < 4096 && k >= 1024) {
++                dmmv_wg = DMMV_WG_SIZE_LARGE;
++            }
++        } else {
++            if (m <= 8192 && k >= 1024) {
++                dmmv_wg = DMMV_WG_SIZE_LARGE;
++            }
++        }
++    }
++
++    if (b_type == GGML_TYPE_Q8_1) {
++        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
++            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
++        }
++        return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
++    }
++
++    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
++}
++
++static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
++    VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
++    vk_buffer buf = ggml_vk_create_buffer(device, size,
++        {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
++         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
++
++    if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
++        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
++            size/1024.0/1024.0);
++        device->device.freeMemory(buf->device_memory);
++        device->device.destroyBuffer(buf->buffer);
++        return nullptr;
++    }
++
++    std::lock_guard guard(device->mutex);
++    device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
++
++    return buf->ptr;
++}
++
++static void ggml_vk_host_free(vk_device& device, void* ptr) {
++    if (ptr == nullptr) {
++        return;
++    }
++    VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
++    std::lock_guard guard(device->mutex);
++
++    vk_buffer buf;
++    size_t index;
++    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
++        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
++        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
++        if (ptr >= addr && ptr < endr) {
++            buf = std::get<2>(device->pinned_memory[i]);
++            index = i;
++            break;
++        }
++    }
++    if (buf == nullptr) {
++        fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
++        return;
++    }
++
++    ggml_vk_destroy_buffer(buf);
++
++    device->pinned_memory.erase(device->pinned_memory.begin() + index);
++}
++
++static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
++    std::lock_guard guard(device->mutex);
++    buf = nullptr;
++    buf_offset = 0;
++    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
++        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
++        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
++        if (ptr >= addr && ptr < endr) {
++            buf = std::get<2>(device->pinned_memory[i]);
++            buf_offset = ((const uint8_t *)ptr) - addr;
++            break;
++        }
++    }
++}
++
++static vk_subbuffer ggml_vk_tensor_subbuffer(
++    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
++
++    vk_buffer buffer = nullptr;
++    size_t offset = 0;
++    if (ctx->device->uma) {
++        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
++    }
++    if (!buffer) {
++        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
++        buffer = buf_ctx->dev_buffer;
++        offset = vk_tensor_offset(tensor) + tensor->view_offs;
++    }
++    GGML_ASSERT(buffer != nullptr);
++
++    size_t size = ggml_nbytes(tensor);
++
++    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
++    // The shader must support misaligned offsets when indexing into the buffer
++    GGML_ASSERT(allow_misalign || misalign_bytes == 0);
++    offset &= ~misalign_bytes;
++    size += misalign_bytes;
++
++    return vk_subbuffer{buffer, offset, size};
++}
++
++static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
++    vk_submission s;
++    s.buffer = ggml_vk_create_cmd_buffer(device, p);
++    if (one_time) {
++        s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
++    } else {
++        s.buffer.begin({ vk::CommandBufferUsageFlags{} });
++    }
++
++    return s;
++}
++
++template  size_t push_constant_size(const T &t) {
++    static_assert(std::is_class::value, "T must be a struct/class");
++    GGML_UNUSED(t);
++    return sizeof(T);
++}
++template  size_t push_constant_size(const std::vector &t) {
++    GGML_UNUSED(t);
++    return sizeof(T) * t.size();
++}
++template  size_t push_constant_size(const std::array &t) {
++    GGML_UNUSED(t);
++    return sizeof(T) * N;
++}
++
++template  const T *push_constant_data(const T &t) {
++    static_assert(std::is_class::value, "T must be a struct/class");
++    return &t;
++}
++template  const T *push_constant_data(const std::vector &t) {
++    return t.data();
++}
++template  const T *push_constant_data(const std::array &t) {
++    return t.data();
++}
++
++template 
++static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, std::array elements) {
++    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
++    const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
++    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
++    VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
++    for (auto& buffer : descriptor_buffer_infos) {
++        std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
++    }
++    std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
++    GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
++                wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
++                wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++    GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
++    GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
++    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
++    GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));
++
++    vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
++    vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
++    ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
++
++    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
++    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
++    subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
++                                pipeline->layout,
++                                0,
++                                { descriptor_set },
++                                {});
++    subctx->s->buffer.dispatch(wg0, wg1, wg2);
++}
++
++static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) {
++    s.buffer.end();
++
++    s.wait_semaphores = std::move(wait_semaphores);
++    s.signal_semaphores = std::move(signal_semaphores);
++}
++
++static void ggml_vk_ctx_end(vk_context& ctx) {
++    VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
++    if (ctx->s == nullptr) {
++        return;
++    }
++
++    ctx->s->buffer.end();
++    ctx->s = nullptr;
++}
++
++static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
++    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
++    if (subctx->s != nullptr) {
++        ggml_vk_ctx_end(subctx);
++    }
++
++    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
++    subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
++}
++
++static size_t ggml_vk_align_size(size_t width, size_t align) {
++    VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
++    return CEIL_DIV(width, align) * align;
++}
++
++static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) {
++    if (memcpys == nullptr) {
++        memcpy(dst, src, size);
++    } else {
++        memcpys->emplace_back(dst, src, size);
++    }
++}
++
++static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector* memsets = nullptr) {
++    if (memsets == nullptr) {
++        memset(dst, val, size);
++    } else {
++        memsets->emplace_back(dst, val, size);
++    }
++}
++
++static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
++    if (device->sync_staging == nullptr || device->sync_staging->size < size) {
++        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
++        ggml_vk_destroy_buffer(device->sync_staging);
++        device->sync_staging = ggml_vk_create_buffer_check(device, size,
++            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
++            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
++    }
++}
++
++static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
++    if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
++        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
++        ggml_vk_destroy_buffer(ctx->sync_staging);
++        ctx->sync_staging = ggml_vk_create_buffer_check(ctx->device, size,
++            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
++            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
++    }
++}
++
++// Asynchronously write a non-contiguous tensor into the device-local buffer
++// `dst` at `offset`, repacking it into a fully contiguous layout.
++//  - If tensor->data lies in pinned (device-registered) host memory, copy
++//    regions are issued directly from the pinned buffer.
++//  - Otherwise the data is gathered into the device's sync staging buffer via
++//    deferred memcpys; this path must be opted into with sync_staging=true.
++static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
++    VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
++    GGML_ASSERT(!ggml_is_contiguous(tensor));
++    // Buffer is already mapped
++    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
++        std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
++        GGML_ABORT("fatal error");
++    }
++    // Check if src is pinned memory
++    vk_buffer buf = nullptr;
++    size_t buf_offset = 0;
++    ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
++
++    const uint64_t ne0 = tensor->ne[0];
++    const uint64_t ne1 = tensor->ne[1];
++    const uint64_t ne2 = tensor->ne[2];
++    const uint64_t ne3 = tensor->ne[3];
++    const uint64_t nb0 = tensor->nb[0];
++    const uint64_t nb1 = tensor->nb[1];
++    const uint64_t nb2 = tensor->nb[2];
++    const uint64_t nb3 = tensor->nb[3];
++    const ggml_type type = tensor->type;
++    const uint64_t ts = ggml_type_size(type);
++    const uint64_t bs = ggml_blck_size(type);
++
++    // Destination strides of the packed (contiguous) layout.
++    const uint64_t dstnb0 = ts;
++    const uint64_t dstnb1 = dstnb0*(ne0/bs);
++    const uint64_t dstnb2 = dstnb1*ne1;
++    const uint64_t dstnb3 = dstnb2*ne2;
++
++    const uint64_t ne = ggml_nelements(tensor);
++
++    if (buf != nullptr) {
++        // Memory is pinned, use as staging buffer
++        std::vector<vk::BufferCopy> slices;  // FIX: explicit element type (CTAD cannot deduce it from an empty declaration)
++
++        for (uint64_t i3 = 0; i3 < ne3; i3++) {
++            for (uint64_t i2 = 0; i2 < ne2; i2++) {
++                // Find longest contiguous slice
++                if (ne1*nb1 == dstnb2) {
++                    slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
++                } else {
++                    for (uint64_t i1 = 0; i1 < ne1; i1++) {
++                        if (ne0*nb0/bs == dstnb1) {
++                            slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
++                        } else {
++                            const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
++                            const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
++                            for (uint64_t i0 = 0; i0 < ne0; i0++) {
++                                // FIX: the per-element source offset must advance with i0 (was i1*nb0,
++                                // which copied the same source byte for every element of the row);
++                                // this mirrors the staging-path loop below.
++                                slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
++                            }
++                        }
++                    }
++                }
++            }
++        }
++
++        ggml_vk_sync_buffers(ctx, subctx);
++        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
++        return;
++    }
++
++    if (!sync_staging) {
++        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
++    }
++
++    // Staging buffer required
++    vk_buffer& staging = ctx->device->sync_staging;
++    const uint64_t copy_size = ts*ne/bs;
++    ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
++    VkBufferCopy buf_copy{ 0, offset, copy_size };
++
++    ggml_vk_sync_buffers(ctx, subctx);
++    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
++
++    // The recorded copy reads the staging buffer; the memcpys below are deferred
++    // and executed on the host before the command buffer is submitted.
++    for (uint64_t i3 = 0; i3 < ne3; i3++) {
++        for (uint64_t i2 = 0; i2 < ne2; i2++) {
++            // Find longest contiguous slice
++            if (ne1*nb1 == dstnb2) {
++                deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
++            } else {
++                for (uint64_t i1 = 0; i1 < ne1; i1++) {
++                    if (ne0*nb0/bs == dstnb1) {
++                        deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
++                    } else {
++                        const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
++                        const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
++                        for (uint64_t i0 = 0; i0 < ne0; i0++) {
++                            deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
++                        }
++                    }
++                }
++            }
++        }
++    }
++}
++
++// Asynchronously write a strided 2D host region (height rows of `width` bytes,
++// source pitch `spitch`) into `dst` at `offset`, packed with pitch == width.
++// Returns true if the copy was recorded (pinned source, or staging with
++// sync_staging=true); returns false when the caller must fall back.
++static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
++    VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
++    // Check if src is pinned memory
++    vk_buffer buf = nullptr;
++    size_t buf_offset = 0;
++    ggml_vk_host_get(dst->device, src, buf, buf_offset);
++
++    if (buf != nullptr) {
++        // Memory is pinned, use as staging buffer
++        std::vector<vk::BufferCopy> slices(1);  // FIX: explicit element type (CTAD cannot deduce it from a size argument)
++        if (width == spitch) {
++            // Only do single write if stride is equal
++            slices[0].srcOffset = buf_offset;
++            slices[0].dstOffset = offset;
++            slices[0].size = width * height;
++        } else {
++            slices.resize(height);
++            for (size_t i = 0; i < height; i++) {
++                slices[i].srcOffset = buf_offset + i * spitch;
++                slices[i].dstOffset = offset + i * width;
++                slices[i].size = width;
++            }
++        }
++
++        ggml_vk_sync_buffers(nullptr, subctx);
++        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
++        return true;
++    }
++    VK_LOG_DEBUG("STAGING");
++
++    if (!sync_staging) {
++        // copy was not handled caller needs to fall back
++        return false;
++    }
++
++    // Staging buffer required
++    const size_t copy_size = width*height;
++    ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
++
++    vk_buffer& staging_buffer = dst->device->sync_staging;
++
++    VkBufferCopy buf_copy = {
++        0,
++        offset,
++        copy_size};
++
++    ggml_vk_sync_buffers(nullptr, subctx);
++    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
++
++    // Deferred host copies fill the staging buffer before submission.
++    if (width == spitch) {
++        deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
++    } else {
++        for (size_t i = 0; i < height; i++) {
++            deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
++        }
++    }
++    return true;
++}
++
++// Asynchronously write `size` linear bytes from host memory into `dst` at
++// `offset`. Implemented as a single-row 2D write with pitch == width.
++static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
++    VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
++    const bool handled = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
++    return handled;
++}
++
++// Synchronously write a strided 2D host region into `dst` at `offset`.
++// Host-visible buffers are written directly through their mapping; otherwise a
++// temporary transfer context is recorded, staged, submitted, and fenced.
++static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
++    VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
++    // Buffer is already mapped
++    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
++        // Direct host writes rely on coherent memory: no explicit flush is issued.
++        GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
++
++        // Destination rows are packed with pitch == width.
++        for (size_t i = 0; i < height; i++) {
++            memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
++        }
++    } else {
++        // Serialize use of the device's transfer queue and shared fence.
++        std::lock_guard guard(dst->device->mutex);
++
++        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
++        ggml_vk_ctx_begin(dst->device, subctx);
++        // sync_staging=true: allow the staging-buffer fallback for non-pinned src.
++        bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
++        GGML_ASSERT(ret);
++        ggml_vk_ctx_end(subctx);
++
++        // Execute the deferred host-side copies/fills now, before submission, so
++        // the staging buffer holds the data when the GPU runs the copy commands.
++        for (auto& cpy : subctx->in_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++
++        for (auto& mset : subctx->memsets) {
++            memset(mset.dst, mset.val, mset.n);
++        }
++
++        // Submit and block until the transfer completes.
++        ggml_vk_submit(subctx, dst->device->fence);
++        VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
++        dst->device->device.resetFences({ dst->device->fence });
++        ggml_vk_queue_command_pools_cleanup(dst->device);
++    }
++}
++
++// Synchronously write `size` linear bytes from host memory into `dst` at
++// `offset` (single-row 2D write; spitch is unused when height == 1).
++static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
++    VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
++    const size_t rows = 1;
++    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, rows);
++}
++
++// Asynchronously read a strided 2D region from `src` at `offset` into host
++// memory `dst` (height rows of `width` bytes, source pitch `spitch`, dest
++// pitch `dpitch`). Returns true if the copy was recorded (pinned destination,
++// or staging with sync_staging=true); false when the caller must fall back.
++static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
++    VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
++    GGML_ASSERT(width > 0);
++    GGML_ASSERT(height > 0);
++    GGML_ASSERT(src != nullptr);
++
++    // TODO: staging_offset is not used
++
++    // Check if dst is pinned memory
++    vk_buffer buf = nullptr;
++    size_t buf_offset = 0;
++    ggml_vk_host_get(src->device, dst, buf, buf_offset);
++
++    // Destination offsets below are relative to the staging buffer when dst is
++    // not pinned (buf_offset stays 0 in that case).
++    std::vector<vk::BufferCopy> slices(1);  // FIX: explicit element type (CTAD cannot deduce it from a size argument)
++    if (width == spitch && width == dpitch) {
++        // Only do single write if stride is equal
++        slices[0].srcOffset = offset;
++        slices[0].dstOffset = buf_offset;
++        slices[0].size = width * height;
++    } else {
++        slices.resize(height);
++        for (size_t i = 0; i < height; i++) {
++            slices[i].srcOffset = offset + i * spitch;
++            slices[i].dstOffset = buf_offset + i * dpitch;
++            slices[i].size = width;
++        }
++    }
++
++    if (buf != nullptr) {
++        // Memory is pinned, use as staging buffer
++        ggml_vk_sync_buffers(nullptr, subctx);
++        subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
++
++        return true;
++    }
++    VK_LOG_DEBUG("STAGING");
++
++    if (!sync_staging) {
++        // copy was not handled caller needs to fall back
++        return false;
++    }
++
++    // Fall back to staging buffer
++    const size_t copy_size = dpitch * height;
++    ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
++
++    vk_buffer& staging_buffer = src->device->sync_staging;
++
++    ggml_vk_sync_buffers(nullptr, subctx);
++    subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
++
++    // Copied to host after the fence signals, by whoever drains out_memcpys.
++    deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
++    return true;
++}
++
++// Asynchronously read `size` linear bytes from `src` at `offset` into host
++// memory: a single-row 2D read with matching source/destination pitches.
++static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
++    const size_t pitch = size;
++    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, pitch, pitch, size, 1, sync_staging);
++}
++
++// Synchronously read `size` bytes from `src` at `offset` into host memory.
++static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
++    VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
++
++    // If the device is not an UMA device the memory is host-accessible through rebar. While writing
++    // through PCIe is sufficient fast reading back data from PCIe is slower than going through
++    // the HW device to host copy path.
++    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
++        // Direct mapped read; relies on coherent memory (no invalidate needed).
++        GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
++
++        memcpy(dst, (uint8_t *) src->ptr + offset, size);
++    } else {
++        // Serialize use of the device's transfer queue and shared fence.
++        std::lock_guard guard(src->device->mutex);
++
++        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
++        ggml_vk_ctx_begin(src->device, subctx);
++        // sync_staging=true: allow the staging-buffer fallback for non-pinned dst.
++        bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
++        GGML_ASSERT(ret);
++        ggml_vk_ctx_end(subctx);
++
++        ggml_vk_submit(subctx, src->device->fence);
++        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
++        src->device->device.resetFences({ src->device->fence });
++        ggml_vk_queue_command_pools_cleanup(src->device);
++
++        // Drain the deferred device->host copies only after the fence signals,
++        // i.e. once the GPU has filled the staging buffer.
++        for (auto& cpy : subctx->out_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++    }
++}
++
++// Record a device-local buffer-to-buffer copy into the given context.
++// Cross-device copies are not possible here; see ggml_vk_buffer_copy.
++static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
++    VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
++    // Make sure both buffers are on same device
++    GGML_ASSERT(src->device == dst->device);
++
++    VkBufferCopy region{};
++    region.srcOffset = src_offset;
++    region.dstOffset = dst_offset;
++    region.size      = size;
++
++    vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &region);
++}
++
++// Synchronously copy `size` bytes between two vk buffers, which may live on
++// different devices. Same-device copies are a single fenced GPU copy;
++// cross-device copies bounce through the source device's sync staging buffer.
++static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
++    if (src->device == dst->device) {
++        std::lock_guard guard(src->device->mutex);
++        VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
++        // Copy within the device
++        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
++        ggml_vk_ctx_begin(src->device, subctx);
++        ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
++        ggml_vk_ctx_end(subctx);
++        ggml_vk_submit(subctx, src->device->fence);
++        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
++        src->device->device.resetFences({ src->device->fence });
++        ggml_vk_queue_command_pools_cleanup(src->device);
++    } else {
++        VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
++        // Copy device to device
++        ggml_vk_ensure_sync_staging_buffer(src->device, size);
++
++        // Copy to src staging buffer
++        // (recursive call takes the same-device branch above and blocks until done)
++        ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
++        // Copy to dst buffer
++        // NOTE(review): reads sync_staging->ptr on the host — the staging buffer is
++        // created host-visible (see ggml_vk_ensure_sync_staging_buffer).
++        ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
++    }
++}
++
++// Record an asynchronous fill of `size` bytes at `offset` in `dst` with `c`.
++static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
++    VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
++
++    const bool host_fill = (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) &&
++                           dst->device->uma;
++    if (host_fill) {
++        // UMA + host-visible: queue a deferred CPU memset instead of a GPU command.
++        deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
++    } else {
++        // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
++        ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
++    }
++}
++
++// Synchronously fill `size` bytes at `offset` in `dst` with `c`.
++// UMA host-visible buffers take a plain CPU memset; otherwise a fenced
++// fillBuffer command is recorded and submitted on the transfer queue.
++static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
++    VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
++
++    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
++        dst->device->uma) {
++        memset((uint8_t*)dst->ptr + offset, c, size);
++        return;
++    }
++
++    // Serialize use of the device's transfer queue and shared fence.
++    std::lock_guard guard(dst->device->mutex);
++    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
++    ggml_vk_ctx_begin(dst->device, subctx);
++    subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
++    ggml_vk_ctx_end(subctx);
++
++    // Submit and block until the fill completes.
++    ggml_vk_submit(subctx, dst->device->fence);
++    VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
++    dst->device->device.resetFences({ dst->device->fence });
++    ggml_vk_queue_command_pools_cleanup(dst->device);
++}
++
++// Heuristically pick a split_k factor for an (m, n, k) matmul: split the k
++// dimension when k is large and the tile grid would leave most of the GPU's
++// shader cores idle. Returns 1 (no split) otherwise.
++static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
++    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
++
++    if (disable_split_k) {
++        return 1;
++    }
++
++    uint32_t split_k = 1;
++    const uint32_t cores = ctx->device->shader_core_count;
++
++    // Only consider splitting when the core count is known, the matrix spans at
++    // least one full tile, and k is 'large'.
++    if (cores != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1] && k >= 2048) {
++        const uint32_t tiles = CEIL_DIV(m, pipeline->wg_denoms[0]) * CEIL_DIV(n, pipeline->wg_denoms[1]);
++
++        // Split when the SMs would fill less than halfway (or two-thirds).
++        if (tiles <= cores / 2) {
++            split_k = cores / tiles;
++        } else if (tiles <= cores * 2 / 3) {
++            split_k = 3;
++        }
++        // Cap the split at 8x. Unless k is huge this is a lot of overhead.
++        split_k = std::min(split_k, 8u);
++
++        // ggml_vk_matmul will align the splits to be a multiple of 256.
++        // If this rounded up size would cause the last split to be empty,
++        // then reduce the split count.
++        for (; split_k > 1; split_k--) {
++            uint32_t k_split = ROUNDUP_POW2(CEIL_DIV(k, split_k), 256);
++            if (k_split * (split_k - 1) < k) {
++                break;
++            }
++        }
++    }
++
++    return split_k;
++}
++
++// Pick the small/medium/large matmul pipeline variant for an (m, n) problem,
++// returning the aligned or unaligned flavor as requested. Selection depends on
++// which variants the device actually compiled (mul_mat_{s,m,l}[src0_type]).
++static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
++    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
++
++    if (ctx->device->coopmat2) {
++        const uint32_t shader_core_count = ctx->device->shader_core_count;
++        // Workgroup-tile counts for the large and medium variants.
++        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
++        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
++
++        // Use large shader when the N dimension is greater than the medium shader's tile size
++        uint32_t crossover_large = mmp->m->wg_denoms[1];
++
++        // Prefer large over medium if either:
++        // - medium or large tiles would overfill the GPU
++        // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
++        //   (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
++        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
++                            // split_k==3 with large tiles likely better than medium tiles with no split_k.
++                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
++
++        // Also fall through to large if neither medium nor small exists for this type.
++        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
++            return aligned ? mmp->a_l : mmp->l;
++        }
++        // Use medium shader when the N dimension is greater than the small shader's tile size
++        uint32_t crossover_medium = mmp->s->wg_denoms[1];
++        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
++            return aligned ? mmp->a_m : mmp->m;
++        }
++        return aligned ? mmp->a_s : mmp->s;
++    }
++
++    // Non-coopmat2 path: pick by problem size, constrained by availability.
++    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
++        return aligned ? mmp->a_s : mmp->s;
++    }
++    if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
++        return aligned ? mmp->a_m : mmp->m;
++    }
++    return aligned ? mmp->a_l : mmp->l;
++
++    GGML_UNUSED(src1_type);
++}
++
++// Return the alignment requirement of the pipeline that would be chosen for an
++// aligned (m, n) matmul with these operand types.
++static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
++    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
++    vk_pipeline chosen = ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type);
++    return chosen->align;
++}
++
++// Dispatch a (possibly batched, possibly split-k) matrix multiply.
++// split_k == 1: one dispatch per chunk of the batch dimension.
++// split_k > 1: partial products are written to split_k_buffer, then reduced
++// into d by the split_k_reduce pipeline.
++static void ggml_vk_matmul(
++        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
++        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
++        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
++        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
++        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
++        uint32_t padded_n) {
++        VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
++    if (split_k == 1) {
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
++
++        // The batch dimension may exceed maxComputeWorkGroupCount[2]; walk it in
++        // chunks, passing the chunk base via the push constants.
++        uint32_t base_work_group_z = 0;
++        while (base_work_group_z < batch) {
++            uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++
++            const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
++            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
++            base_work_group_z += groups_z;
++        }
++        return;
++    }
++
++    // The split_k scratch buffer is shared across ops: barrier if a previous op
++    // wrote it and the write has not been synchronized yet.
++    if (ctx->prealloc_split_k_need_sync) {
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++
++    GGML_ASSERT(batch_stride_d == m * n);
++
++    // Round the split size up to a multiple of 256 (k-quant alignment)
++    uint32_t k_split = CEIL_DIV(k, split_k);
++    k_split = ROUNDUP_POW2(k_split, 256);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
++
++    uint32_t base_work_group_z = 0;
++    while (base_work_group_z < batch) {
++        uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++
++        const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
++        // Make sure enough workgroups get assigned for split k to work
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
++        base_work_group_z += groups_z;
++    }
++    // Barrier between producing the partial sums and reducing them.
++    ggml_vk_sync_buffers(ctx, subctx);
++    const std::array pc2 = { (uint32_t)(m * n * batch), split_k };
++    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
++    ctx->prealloc_split_k_need_sync = true;
++}
++
++// Pick the small/medium/large mul_mat_id pipeline variant for an (m, n)
++// problem, constrained by which variants exist for this src0 type.
++static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
++    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
++
++    const bool has_l = ctx->device->mul_mat_id_l[src0_type];
++    const bool has_m = ctx->device->mul_mat_id_m[src0_type];
++    const bool has_s = ctx->device->mul_mat_id_s[src0_type];
++
++    if (ctx->device->coopmat2) {
++        // Large when N exceeds the medium shader's tile height (or nothing smaller exists).
++        const uint32_t large_cutoff = mmp->m->wg_denoms[1];
++        if ((has_l && n > large_cutoff) || (!has_m && !has_s)) {
++            return aligned ? mmp->a_l : mmp->l;
++        }
++        // Medium when N exceeds the small shader's tile height (or small is unavailable).
++        const uint32_t medium_cutoff = mmp->s->wg_denoms[1];
++        if ((has_m && n > medium_cutoff) || !has_s) {
++            return aligned ? mmp->a_m : mmp->m;
++        }
++        return aligned ? mmp->a_s : mmp->s;
++    }
++
++    if ((has_s && (m <= 32 || n <= 32)) || (!has_m && !has_l)) {
++        return aligned ? mmp->a_s : mmp->s;
++    }
++    if ((has_m && (m <= 64 || n <= 64)) || !has_l) {
++        return aligned ? mmp->a_m : mmp->m;
++    }
++    return aligned ? mmp->a_l : mmp->l;
++}
++
++// Return the alignment requirement of the mul_mat_id pipeline that would be
++// chosen for an aligned (m, n) dispatch.
++static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
++    // FIX: log tag previously read "ggml_vk_guess_matmul_pipeline_align"
++    // (copy-paste from the non-id variant), mislabeling the trace output.
++    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
++    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
++}
++
++// Dispatch an indirect (expert-routed) matrix multiply: `ids` selects which of
++// the n_as expert matrices in `a` each row of `b` is multiplied with.
++// Workgroup grid is (m, nei1, n_as).
++static void ggml_vk_matmul_id(
++        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
++        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
++        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
++        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
++        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
++        uint32_t padded_n) {
++    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
++        "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
++        "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
++        "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
++    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
++                                              nei0, nei1, nbi1, ne11, padded_n };
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
++}
++
++// True if the tensor's first two dimensions are packed (no padding between
++// elements or rows) and dim3 is either singleton or densely follows dim2.
++static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
++    // dim0: element stride equals the type size.
++    if (tensor->nb[0] != ggml_type_size(tensor->type)) {
++        return false;
++    }
++    // dim1: row stride equals one packed row of dim0 blocks.
++    if (tensor->nb[1] != (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type)) {
++        return false;
++    }
++    return tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
++}
++
++// Select the copy/convert pipeline for copying `src` (optionally into `dst`)
++// while converting to element type `to`. Aborts if no pipeline exists for the
++// type pair.
++static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
++
++    // Choose "contiguous copy" shader if src/dst are contiguous
++    bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
++
++    // Use optimized "transpose" shader if src dim1 is the innermost dimension.
++    bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);
++
++    // Transpose fast path only for same-type 2- or 4-byte elements.
++    if (transpose && src->type == to) {
++        if (ggml_type_size(to) == 4) {
++            return ctx->device->pipeline_cpy_transpose_32;
++        } else if (ggml_type_size(to) == 2) {
++            return ctx->device->pipeline_cpy_transpose_16;
++        }
++    }
++
++    // Explicit float/int conversion pairs, each with a contiguous variant.
++    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f32_f32;
++        } else {
++            return ctx->device->pipeline_cpy_f32_f32;
++        }
++    }
++    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f32_f16;
++        } else {
++            return ctx->device->pipeline_cpy_f32_f16;
++        }
++    }
++    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f16_f16;
++        } else {
++            return ctx->device->pipeline_cpy_f16_f16;
++        }
++    }
++    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f16_f32;
++        } else {
++            return ctx->device->pipeline_cpy_f16_f32;
++        }
++    }
++    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f32_bf16;
++        } else {
++            return ctx->device->pipeline_cpy_f32_bf16;
++        }
++    }
++    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_f32_i32;
++        } else {
++            return ctx->device->pipeline_cpy_f32_i32;
++        }
++    }
++    if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
++        if (contig) {
++            return ctx->device->pipeline_contig_cpy_i32_f32;
++        } else {
++            return ctx->device->pipeline_cpy_i32_f32;
++        }
++    }
++    // f32 -> quantized (quantizing copy).
++    if (src->type == GGML_TYPE_F32) {
++        switch (to) {
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_IQ4_NL:
++            return ctx->device->pipeline_cpy_f32_quant[to];
++        default:
++            break;
++        }
++    }
++
++    // quantized -> f32 (dequantizing copy).
++    if (to == GGML_TYPE_F32) {
++        switch (src->type) {
++        case GGML_TYPE_Q4_0:
++        case GGML_TYPE_Q4_1:
++        case GGML_TYPE_Q5_0:
++        case GGML_TYPE_Q5_1:
++        case GGML_TYPE_Q8_0:
++        case GGML_TYPE_IQ4_NL:
++            return ctx->device->pipeline_cpy_quant_f32[src->type];
++        default:
++            break;
++        }
++    }
++
++    if (src->type == to) {
++        // Copy two or four bytes at a time, depending on block size.
++        // For quantized types, we scale by block size/type size. But
++        // this path is also used for bf16->bf16 for example, where the
++        // type size must be exactly 2 or 4.
++        GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
++        if ((ggml_type_size(src->type) % 4) == 0) {
++            if (contig) {
++                return ctx->device->pipeline_contig_cpy_f32_f32;
++            } else {
++                return ctx->device->pipeline_cpy_f32_f32;
++            }
++        } else {
++            if (contig) {
++                return ctx->device->pipeline_contig_cpy_f16_f16;
++            } else {
++                return ctx->device->pipeline_cpy_f16_f16;
++            }
++        }
++    }
++
++    std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
++    GGML_ABORT("fatal error");
++}
++
++// Copy a (possibly strided) tensor from `in` into `out` as a densely packed
++// contiguous tensor, using the given copy pipeline, then barrier.
++static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, const vk_subbuffer & in, const vk_subbuffer & out) {
++    VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
++    std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
++    const int tensor_type_size = ggml_type_size(tensor->type);
++
++    const uint32_t ne = ggml_nelements(tensor);
++    // FIX: explicit template arguments — `std::array elements;` has no
++    // initializer, so class template argument deduction cannot apply.
++    std::array<uint32_t, 3> elements;
++
++    // Spread the element count over up to three grid dimensions (max 512 each
++    // in x/y).
++    if (ne > 262144) {
++        elements = { 512, 512, CEIL_DIV(ne, 262144) };
++    } else if (ne > 512) {
++        elements = { 512, CEIL_DIV(ne, 512), 1 };
++    } else {
++        elements = { ne, 1, 1 };
++    }
++
++    // Source strides come from the tensor; destination strides describe the
++    // packed layout.
++    vk_op_unary_push_constants pc = {
++        (uint32_t)ne,
++        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
++        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3],                       1                   , (uint32_t)tensor->ne[0]                   , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
++        0,
++        0.0f, 0.0f,
++        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
++    };
++    init_pushconst_fastdiv(pc);
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
++    ggml_vk_sync_buffers(ctx, subctx);
++}
++
++// Look up the quantization compute pipeline for the given destination type.
++// Only Q8_1 (x4-packed) quantization is implemented on the Vulkan backend;
++// any other type is a programmer error and aborts.
++static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
++    if (type == GGML_TYPE_Q8_1) {
++        return ctx->device->pipeline_quantize_q8_1_x4;
++    }
++    std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
++    GGML_ABORT("fatal error");
++}
++
++// Quantize ne f32 elements from `in` into Q8_1 (x4-packed) blocks in `out`.
++// The dispatch is clamped to the device's max workgroup count; the shader
++// loops internally over all num_blocks blocks regardless of clamping.
++static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, const vk_subbuffer & in, const vk_subbuffer & out, uint32_t ne) {
++    VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
++
++    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++
++    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
++    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
++    // Compute in 64-bit to avoid overflow, then cap at what fits in uint32_t.
++    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
++    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
++
++    const vk_quantize_q8_1_push_constants pc = {
++        ne,
++        num_blocks,
++    };
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });
++    ggml_vk_sync_buffers(ctx, subctx);
++}
++
++// If the 64-bit-indexing device extension is available, walk the pipeline's
++// variant chain and return the variant compiled with 64-bit indexing.
++// Falls back to the pipeline as given when no such variant exists.
++static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
++    GGML_UNUSED(ctx);
++#if defined(VK_EXT_shader_64bit_indexing)
++    for (vk_pipeline *cur = &pipeline; *cur; cur = &(*cur)->next) {
++        if ((*cur)->is_64b_indexing) {
++            return *cur;
++        }
++    }
++#endif
++    return pipeline;
++}
++
++// Matrix-matrix multiply. Converts/dequantizes non-contiguous or mismatched
++// inputs into preallocated scratch buffers, optionally quantizes src1 to Q8_1
++// for integer-dot-product (mmq) pipelines, then dispatches the matmul
++// pipeline, with an optional split-k reduction pass.
++static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
++    VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << "))");
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    const uint64_t ne03 = src0->ne[3];
++
++    const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    const uint64_t ne13 = src1->ne[3];
++
++    const uint64_t ne21 = dst->ne[1];
++    const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
++    const uint32_t stride_batch_d = stride_d*ne21;
++
++    // Broadcast ratios for batched (grouped) multiplication.
++    const uint64_t r2 = ne12 / ne02;
++    const uint64_t r3 = ne13 / ne03;
++
++    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
++    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
++
++    vk_buffer d_Qx = nullptr;
++    size_t qx_buf_offset = 0;
++    vk_buffer d_Qy = nullptr;
++    size_t qy_buf_offset = 0;
++
++    bool src0_uma = false;
++    bool src1_uma = false;
++
++    // On unified-memory devices the host pointers may already be visible.
++    if (ctx->device->uma) {
++        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
++        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
++        src0_uma = d_Qx != nullptr;
++        src1_uma = d_Qy != nullptr;
++    }
++
++    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
++    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
++                              !ggml_vk_dim01_contiguous(src0);
++    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
++                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
++                              !ggml_vk_dim01_contiguous(src1);
++
++    // If src0 is BF16, try to use a BF16 x BF16 multiply
++    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
++
++    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
++
++    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
++
++    // Check for mmq first
++    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
++
++    if (mmp == nullptr) {
++        // Fall back to f16 dequant mul mat
++        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
++        quantize_y = false;
++    }
++
++    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
++    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
++
++    if (qx_needs_dequant) {
++        // Fall back to dequant + f16 mulmat
++        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
++    }
++
++    // Not implemented
++    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
++
++    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
++    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
++
++    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
++
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
++    }
++
++    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
++    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
++    const uint64_t x_ne = ggml_nelements(src0);
++    // 128 elements per Q8_1 x4 block
++    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
++    const uint64_t d_ne = ggml_nelements(dst);
++
++    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
++
++    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
++    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
++    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
++    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
++    const uint64_t d_sz = sizeof(float) * d_ne;
++
++    vk_pipeline to_fp16_vk_0 = nullptr;
++    vk_pipeline to_fp16_vk_1 = nullptr;
++    vk_pipeline to_q8_1 = nullptr;
++
++    if (x_non_contig) {
++        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
++    } else {
++        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
++    }
++    if (y_non_contig) {
++        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
++    } else {
++        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
++    }
++    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
++    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
++
++    if (quantize_y) {
++        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++    }
++
++    {
++        const uint64_t split_k_size = split_k > 1 ? d_sz * split_k : 0;
++        if (
++                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
++                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
++                (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
++            GGML_ABORT("Requested preallocation size is too large");
++        }
++        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
++            ctx->prealloc_size_x = x_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
++            ctx->prealloc_size_y = y_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
++            ctx->prealloc_size_split_k = split_k_size;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++
++        // Request descriptor sets
++        if (qx_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
++        }
++        if (qy_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
++        }
++        if (quantize_y) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
++        }
++        if (split_k > 1) {
++            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
++        }
++    }
++
++    vk_buffer d_D = dst_buf_ctx->dev_buffer;
++    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
++    GGML_ASSERT(d_D != nullptr);
++    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);
++    vk_buffer d_X;
++    uint64_t x_buf_offset = 0;
++    vk_buffer d_Y;
++    uint64_t y_buf_offset = 0;
++    if (!src0_uma) {
++        d_Qx = src0_buf_ctx->dev_buffer;
++        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
++        GGML_ASSERT(d_Qx != nullptr);
++    }
++    if (!src1_uma) {
++        d_Qy = src1_buf_ctx->dev_buffer;
++        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
++        GGML_ASSERT(d_Qy != nullptr);
++    }
++    if (qx_needs_dequant) {
++        d_X = ctx->prealloc_x;
++        GGML_ASSERT(d_X->size >= x_sz);
++    } else {
++        d_X = d_Qx;
++        x_buf_offset = qx_buf_offset;
++        GGML_ASSERT(qx_sz == x_sz);
++    }
++    if (qy_needs_dequant) {
++        d_Y = ctx->prealloc_y;
++        GGML_ASSERT(d_Y->size >= y_sz);
++    } else if (quantize_y) {
++        d_Y = ctx->prealloc_y;
++        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
++    } else {
++        d_Y = d_Qy;
++        y_buf_offset = qy_buf_offset;
++        GGML_ASSERT(qy_sz == y_sz);
++    }
++
++    if (x_non_contig || qx_needs_dequant) {
++        if (ctx->prealloc_x_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++    }
++
++    if (x_non_contig) {
++        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
++    } else if (qx_needs_dequant) {
++        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
++        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)(x_ne), 1, 1});
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++    if (y_non_contig) {
++        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
++            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++    if (quantize_y) {
++        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
++            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++
++    uint32_t stride_batch_x = ne00*ne01;
++    uint32_t stride_batch_y = ne10*ne11;
++
++    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
++        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
++    }
++
++    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
++        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
++    }
++
++    // compute
++    ggml_vk_matmul(
++        ctx, subctx, pipeline,
++        { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
++        ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k },
++        ne01, ne11, ne10,
++        ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
++        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
++    );  // NOLINT
++
++    if (x_non_contig || qx_needs_dequant) {
++        ctx->prealloc_x_need_sync = true;
++    }
++    if (y_non_contig || quantize_y) {
++        ctx->prealloc_y_need_sync = true;
++    }
++}
++
++// Device tuning
++// Decide whether the quantized (integer dot product) mat-vec path is worth
++// using for this shape/type on this device. Returns the heuristic answer
++// unless the mmvq mode forces it on (1) or off (-1).
++static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {
++    GGML_UNUSED(m);
++
++    if (device->mmvq_mode == 1) {
++        return true;
++    }
++    if (device->mmvq_mode == -1) {
++        return false;
++    }
++
++    // General performance issue with q3_k and q6_k due to 2-byte alignment
++    if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
++        return false;
++    }
++
++    // MMVQ is generally good for batches
++    if (n > 1) {
++        return true;
++    }
++
++    // Quantization overhead is not worth it for small k
++    if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
++        if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
++            return true;
++        }
++        if (k <= 4096) {
++            return false;
++        }
++        if (src0_type == GGML_TYPE_MXFP4 || src0_type == GGML_TYPE_Q8_0) {
++            return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
++        }
++        return true;
++    }
++    if (device->vendor_id == VK_VENDOR_ID_AMD) {
++        if (k < 2048) {
++            return false;
++        }
++        if (src0_type == GGML_TYPE_Q8_0) {
++            return device->architecture == vk_device_architecture::AMD_GCN;
++        }
++        return true;
++    }
++    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
++        if (k < 2048) {
++            return false;
++        }
++        // From tests on A770 Linux, may need more tuning
++        if (src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q5_1) {
++            return false;
++        }
++        return true;
++    }
++    // Unknown vendor: default to using MMVQ.
++    return true;
++}
++
++// Matrix-vector multiply. Makes non-contiguous inputs contiguous via a copy
++// pre-pass, optionally quantizes src1 to Q8_1 for integer-dot-product
++// kernels (when ggml_vk_should_use_mmvq approves), and supports up to two
++// fused bias-add graph nodes following the mat-vec node.
++static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1];
++
++    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << ")),)");
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    const uint64_t ne03 = src0->ne[3];
++
++    const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    const uint64_t ne13 = src1->ne[3];
++
++    const uint64_t ne20 = dst->ne[0];
++    const uint64_t ne21 = dst->ne[1];
++    // const uint64_t ne22 = dst->ne[2];
++    // const uint64_t ne23 = dst->ne[3];
++
++    // Broadcast ratios for grouped batches.
++    const uint64_t r2 = ne12 / ne02;
++    const uint64_t r3 = ne13 / ne03;
++
++    // batch_n indicates that we need to compute a few vector results, and this assumes
++    // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
++    GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
++    bool batch_n = ne11 > 1;
++
++    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
++    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
++
++    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
++    // Quantize src1 to Q8_1 only when the device supports integer dot product,
++    // src1 is contiguous f32, the element count is x4-block friendly, and the
++    // per-device heuristic says MMVQ is a win for this shape/type.
++    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
++
++    vk_pipeline to_fp16_vk_0 = nullptr;
++    vk_pipeline to_fp16_vk_1 = nullptr;
++    if (x_non_contig) {
++        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
++    }
++    if (y_non_contig) {
++        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
++    } else {
++        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
++    }
++
++    // Check for mmq first
++    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
++    vk_pipeline to_q8_1 = nullptr;
++
++    if (dmmv == nullptr) {
++        // Fall back to f16 dequant mul mat
++        dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
++        quantize_y = false;
++    }
++
++    if (quantize_y) {
++        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++    }
++
++    // Switch to the 64-bit-indexing pipeline variant when src0 exceeds the
++    // device's maximum storage buffer range.
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
++    }
++
++    const bool qx_needs_dequant = x_non_contig;
++    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
++
++    // Not implemented
++    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
++
++    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
++    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
++    GGML_ASSERT(dmmv != nullptr);
++
++    const uint64_t x_ne = ggml_nelements(src0);
++    const uint64_t y_ne = ggml_nelements(src1);
++
++    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
++    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
++    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
++                         (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
++
++    // Grow the preallocated scratch buffers if needed (aborts when the
++    // requested size exceeds the device storage buffer range).
++    {
++        if (
++                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
++                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
++            GGML_ABORT("Requested preallocation size is too large");
++        }
++        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
++            ctx->prealloc_size_x = x_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
++            ctx->prealloc_size_y = y_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++
++        // Request descriptor sets
++        if (qx_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
++        }
++        if (qy_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
++        }
++        if (quantize_y) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
++        }
++    }
++
++    // D comes from the last node of the fused group (the final bias add, if any).
++    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
++    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
++    vk_subbuffer d_X, d_Y;
++
++    if (qx_needs_dequant) {
++        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
++    } else {
++        d_X = d_Qx;
++        GGML_ASSERT(qx_sz == x_sz);
++    }
++    if (qy_needs_dequant || quantize_y) {
++        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
++    } else {
++        d_Y = d_Qy;
++    }
++
++    if (x_non_contig) {
++        if (ctx->prealloc_x_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++
++        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
++        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
++    }
++    // The y scratch buffer is reused across nodes; only re-fill it when the
++    // pipeline or tensor differs from what it last held.
++    if (y_non_contig) {
++        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
++        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
++            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++    if (quantize_y) {
++        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
++            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++
++    // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
++    uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
++    uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
++    uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
++
++    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
++        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
++    }
++
++    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
++        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
++    }
++
++    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
++
++    uint32_t groups_x = ne01;
++    uint32_t groups_z = 1;
++
++    // Spill into the z dimension when the row count exceeds the device's
++    // x-dimension workgroup-count limit.
++    if (ne01 > max_groups_x) {
++        groups_z = 64;
++        groups_x = CEIL_DIV(groups_x, groups_z);
++    }
++
++    uint32_t fusion_flags = 0;
++
++    // Up to two following ADD nodes may be fused in as bias tensors F0/F1.
++    vk_subbuffer d_F0 = d_D;
++    if (ctx->num_additional_fused_ops > 0) {
++        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
++        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
++
++        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
++    }
++
++    vk_subbuffer d_F1 = d_D;
++    if (ctx->num_additional_fused_ops == 2) {
++        const ggml_tensor * add = cgraph->nodes[node_idx + 2];
++        const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
++
++        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
++    }
++
++    // One dispatch per chunk of ne12*ne13 batches, each chunk bounded by the
++    // device's y-dimension workgroup-count limit.
++    ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
++
++    uint32_t base_work_group_y = 0;
++    while (base_work_group_y < ne12 * ne13) {
++
++        uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++        const vk_mat_vec_push_constants pc = {
++            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
++            stride_batch_x, stride_batch_y, stride_batch_d,
++            fusion_flags, base_work_group_y,
++            (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
++        };
++        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
++                                  {
++                                    d_X,
++                                    d_Y,
++                                    d_D,
++                                    d_F0,
++                                    d_F1,
++                                  },
++                                  pc, { groups_x, groups_y, groups_z });
++        base_work_group_y += groups_y;
++    }
++
++    // Mark scratch buffers dirty so the next user synchronizes before reuse.
++    if (x_non_contig) {
++        ctx->prealloc_x_need_sync = true;
++    }
++    if (y_non_contig || quantize_y) {
++        ctx->prealloc_y_need_sync = true;
++    }
++}
++
++// Specialized mat-vec multiply for a p021-permuted f16 src0 against an f32
++// src1 vector (ne11 == 1), as asserted below. Supports grouped-query
++// attention (gqa_ratio Q heads per K/V head, up to 8) and up to two fused
++// bias-add nodes following the mat-vec node.
++static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1];
++    VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << "))");
++    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
++    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);  // NOLINT
++    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);  // NOLINT
++    GGML_ASSERT(src0->type == GGML_TYPE_F16);
++    GGML_ASSERT(src1->type == GGML_TYPE_F32);
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    // const uint64_t ne03 = src0->ne[3];
++
++    //const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    // const uint64_t ne13 = src1->ne[3];
++
++    GGML_ASSERT(ne11 == 1);
++
++    // With grouped query attention there are > 1 Q matrices per K, V matrix.
++    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
++    // Fall back to ratio 1 when the ratio is not an exact small divisor
++    // (only ratios 1..8 have dedicated pipeline variants).
++    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
++        gqa_ratio = 1;
++    }
++
++    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
++
++    // Switch to the 64-bit-indexing variant when src0 exceeds the device's
++    // maximum storage buffer range.
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
++    }
++
++    {
++        // Request descriptor sets
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++    }
++
++    // D comes from the last node of the fused group (the final bias add, if any).
++    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
++    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
++
++    vk_subbuffer d_F0 = d_D;
++
++    uint32_t fusion_flags = 0;
++
++    // Up to two following ADD nodes may be fused in as bias tensors F0/F1.
++    if (ctx->num_additional_fused_ops > 0) {
++        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
++        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
++
++        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
++    }
++
++    vk_subbuffer d_F1 = d_D;
++    if (ctx->num_additional_fused_ops > 1) {
++        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
++
++        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
++    }
++
++    // compute
++
++    vk_mat_vec_p021_push_constants pc = {
++        (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
++        0, 0, fusion_flags
++    };
++
++    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
++
++    uint32_t workgroups_z = (uint32_t)ne12;
++    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
++    if (gqa_ratio > 1) {
++        workgroups_z /= gqa_ratio;
++    }
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++        {
++            d_Qx,
++            d_Qy,
++            d_D,
++            d_F0,
++            d_F1,
++        }, pc, { 1, (uint32_t)ne01, workgroups_z });
++}
++
++static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1];
++    VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << "))");
++    GGML_ASSERT(!ggml_is_transposed(src0));
++    GGML_ASSERT(!ggml_is_transposed(src1));
++    GGML_ASSERT(!ggml_is_permuted(src0));
++    GGML_ASSERT(src0->type == GGML_TYPE_F16);
++    GGML_ASSERT(src1->type == GGML_TYPE_F32);
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    const uint64_t ne03 = src0->ne[3];
++
++    const uint64_t nb01 = src0->nb[1];
++    const uint64_t nb02 = src0->nb[2];
++
++    const uint64_t nb12 = src1->nb[2];
++
++    // const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    // const uint64_t ne13 = src1->ne[3];
++
++    const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
++    const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
++    const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
++
++    GGML_ASSERT(ne11 == 1);
++    GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
++
++    const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
++    const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
++    const uint32_t channel_stride_y = nb12 / sizeof(float);
++
++    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
++    }
++
++    {
++        // Request descriptor sets
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++    }
++
++    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
++    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
++    vk_subbuffer d_F0 = d_D;
++
++    uint32_t fusion_flags = 0;
++
++    if (ctx->num_additional_fused_ops > 0) {
++        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
++        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
++
++        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
++    }
++
++    vk_subbuffer d_F1 = d_D;
++    if (ctx->num_additional_fused_ops > 1) {
++        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
++
++        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
++    }
++
++    // compute
++    vk_mat_vec_nc_push_constants pc = {
++        (uint32_t)ne00, (uint32_t)ne01,
++        row_stride_x, channel_stride_x, channel_stride_y,
++        (uint32_t)(ne12 / ne02), (uint32_t)ne12,
++        0, 0,
++        nb03, nb13, nb23, fusion_flags
++    };
++
++    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++        {
++            d_Qx,
++            d_Qy,
++            d_D,
++            d_F0,
++            d_F1,
++        }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
++}
++
++static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    ggml_tensor * src0 = dst->src[0];
++    ggml_tensor * src1 = dst->src[1];
++    VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
++
++    // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
++    // where the M dimension is very large.
++    // Split_k doesn't work with M splitting.
++    // This only supports batchsize == 1.
++    const size_t nbytes = ggml_nbytes(src0);
++    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
++    if (needs_split) {
++        // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
++        const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
++        uint32_t m_offset = 0;
++        while (m_offset < dst->ne[0]) {
++            const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
++            ggml_tensor dst2 = *dst;
++            ggml_tensor src02 = *src0;
++
++            dst2.view_src = dst->view_src ? dst->view_src : dst;
++            src02.view_src = src0->view_src ? src0->view_src : src0;
++
++            dst2.view_offs += m_offset * dst->nb[0];
++            src02.view_offs += m_offset * src0->nb[1];
++            dst2.ne[0] = cur_M_size;
++            src02.ne[1] = cur_M_size;
++
++            ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);
++
++            m_offset += cur_M_size;
++        }
++    } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
++        // detect 0213 permutation, and batch size of 1
++        src0->nb[0] <= src0->nb[2] &&
++        src0->nb[2] <= src0->nb[1] &&
++        src0->nb[1] <= src0->nb[3] &&
++        src1->nb[0] <= src1->nb[2] &&
++        src1->nb[2] <= src1->nb[1] &&
++        src1->nb[1] <= src1->nb[3] &&
++        src0->ne[3] == 1 &&
++        src1->ne[3] == 1 &&
++        src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
++        src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
++        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
++    } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
++               !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
++               src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
++               src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
++               src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
++        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
++    // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
++    // when ne12 and ne13 are one.
++    } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
++               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
++        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);
++    } else {
++        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);
++    }
++}
++
++static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
++    VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
++    GGML_ASSERT(ids->type == GGML_TYPE_I32);
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    // const uint64_t ne03 = src0->ne[3];
++
++    const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    const uint64_t ne13 = src1->ne[3];
++
++    const uint64_t nei0 = ids->ne[0];
++    const uint64_t nei1 = ids->ne[1];
++
++    const uint32_t nbi0 = ids->nb[0];
++    const uint32_t nbi1 = ids->nb[1];
++    const uint32_t nbi2 = ids->nb[2];
++
++    const uint64_t ne20 = dst->ne[0];
++    const uint64_t ne21 = dst->ne[1];
++    // const uint64_t ne22 = dst->ne[2];
++    // const uint64_t ne23 = dst->ne[3];
++
++    const uint64_t n_as = ne02;
++
++    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
++    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
++    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
++
++    vk_buffer d_Qx = nullptr;
++    size_t qx_buf_offset = 0;
++    vk_buffer d_Qy = nullptr;
++    size_t qy_buf_offset = 0;
++    vk_buffer d_ids = nullptr;
++    size_t ids_buf_offset = 0;
++
++    bool src0_uma = false;
++    bool src1_uma = false;
++    bool ids_uma = false;
++
++    if (ctx->device->uma) {
++        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
++        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
++        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
++        src0_uma = d_Qx != nullptr;
++        src1_uma = d_Qy != nullptr;
++        ids_uma = d_ids != nullptr;
++    }
++
++    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
++    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
++                              !ggml_vk_dim01_contiguous(src0);
++    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
++                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
++                              !ggml_vk_dim01_contiguous(src1);
++
++    // If src0 is BF16, try to use a BF16 x BF16 multiply
++    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
++
++    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
++
++    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
++
++    // Check for mmq first
++    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
++
++    if (mmp == nullptr) {
++        // Fall back to f16 dequant mul mat
++        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
++        quantize_y = false;
++    }
++
++    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
++    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
++
++    if (qx_needs_dequant) {
++        // Fall back to dequant + f16 mulmat
++        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
++    }
++
++    // Not implemented
++    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
++
++    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
++    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
++
++    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
++
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
++    }
++    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
++    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
++    const uint64_t x_ne = ggml_nelements(src0);
++    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
++    const uint64_t d_ne = ggml_nelements(dst);
++
++    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
++    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
++    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
++    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
++    const uint64_t ids_sz = nbi2;
++    const uint64_t d_sz = sizeof(float) * d_ne;
++
++    vk_pipeline to_fp16_vk_0 = nullptr;
++    vk_pipeline to_fp16_vk_1 = nullptr;
++    vk_pipeline to_q8_1 = nullptr;
++
++    if (x_non_contig) {
++        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
++    } else {
++        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
++    }
++    if (y_non_contig) {
++        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
++    } else {
++        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
++    }
++    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
++    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
++
++    if (quantize_y) {
++        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++    }
++    vk_pipeline count_experts = ctx->device->pipeline_count_experts;
++
++    uint32_t expert_count_size = sizeof(uint32_t) * n_as;
++
++    {
++        if (
++                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
++                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
++            GGML_ABORT("Requested preallocation size is too large");
++        }
++        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
++            ctx->prealloc_size_x = x_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
++            ctx->prealloc_size_y = y_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (ctx->prealloc_size_split_k < expert_count_size) {
++            ctx->prealloc_size_split_k = expert_count_size;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++
++        // Request descriptor sets
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++        if (qx_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
++        }
++        if (qy_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
++        }
++        if (quantize_y) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
++        }
++        ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
++    }
++
++    vk_buffer d_D = dst_buf_ctx->dev_buffer;
++    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
++    GGML_ASSERT(d_D != nullptr);
++    vk_buffer d_X;
++    uint64_t x_buf_offset = 0;
++    vk_buffer d_Y;
++    uint64_t y_buf_offset = 0;
++    if (!src0_uma) {
++        d_Qx = src0_buf_ctx->dev_buffer;
++        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
++        GGML_ASSERT(d_Qx != nullptr);
++    }
++    if (!src1_uma) {
++        d_Qy = src1_buf_ctx->dev_buffer;
++        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
++        GGML_ASSERT(d_Qy != nullptr);
++    }
++    if (!ids_uma) {
++        d_ids = ids_buf_ctx->dev_buffer;
++        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
++        GGML_ASSERT(d_ids != nullptr);
++    }
++    if (qx_needs_dequant) {
++        d_X = ctx->prealloc_x;
++        GGML_ASSERT(d_X->size >= x_sz);
++    } else {
++        d_X = d_Qx;
++        x_buf_offset = qx_buf_offset;
++        GGML_ASSERT(qx_sz == x_sz);
++    }
++    if (qy_needs_dequant) {
++        d_Y = ctx->prealloc_y;
++        GGML_ASSERT(d_Y->size >= y_sz);
++    } else if (quantize_y) {
++        d_Y = ctx->prealloc_y;
++        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
++    } else {
++        d_Y = d_Qy;
++        y_buf_offset = qy_buf_offset;
++        GGML_ASSERT(qy_sz == y_sz);
++    }
++
++    if (x_non_contig || qx_needs_dequant) {
++        if (ctx->prealloc_x_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++    }
++    // Count how many times each expert is used
++    vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
++    if (ctx->prealloc_split_k_need_sync) {
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++    {
++        const std::vector pc = { (uint32_t)nei0,
++                                           (uint32_t)nei1,
++                                           (uint32_t)(nbi0 / ggml_type_size(ids->type)),
++                                           (uint32_t)(nbi1 / ggml_type_size(ids->type)),
++                                           (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
++        ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
++            { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
++    }
++
++    if (x_non_contig) {
++        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
++    } else if (qx_needs_dequant) {
++        const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
++        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
++            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
++    }
++    if (y_non_contig) {
++        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
++            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++    if (quantize_y) {
++        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
++            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++    ggml_vk_sync_buffers(ctx, subctx);
++
++    uint32_t stride_batch_x = ne00*ne01;
++    uint32_t stride_batch_y = ne10*ne11;
++
++    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
++        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
++    }
++
++    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
++        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
++    }
++
++    // compute
++    ggml_vk_matmul_id(
++        ctx, subctx, pipeline,
++        { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
++        { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
++        ne01, ne21, ne10, ne10, ne10, ne01,
++        stride_batch_x, stride_batch_y, ne20*ne21,
++        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
++    );  // NOLINT
++
++    if (x_non_contig || qx_needs_dequant) {
++        ctx->prealloc_x_need_sync = true;
++    }
++    if (y_non_contig || quantize_y) {
++        ctx->prealloc_y_need_sync = true;
++    }
++    ctx->prealloc_split_k_need_sync = true;
++}
++
++static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    ggml_tensor * src0 = dst->src[0];
++    ggml_tensor * src1 = dst->src[1];
++    ggml_tensor * ids = dst->src[2];
++    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << "))");
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
++    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
++    GGML_ASSERT(ids->type == GGML_TYPE_I32);
++
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    // const uint64_t ne02 = src0->ne[2];
++    // const uint64_t ne03 = src0->ne[3];
++
++    const uint64_t ne10 = src1->ne[0];
++    const uint64_t ne11 = src1->ne[1];
++    const uint64_t ne12 = src1->ne[2];
++    // const uint64_t ne13 = src1->ne[3];
++
++    const uint64_t nei0 = ids->ne[0];
++    const uint64_t nei1 = ids->ne[1];
++    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
++
++    const uint64_t ne20 = dst->ne[0];
++    const uint64_t ne21 = dst->ne[1];
++    // const uint64_t ne22 = dst->ne[2];
++    // const uint64_t ne23 = dst->ne[3];
++
++    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
++    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
++
++    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
++    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
++
++    vk_pipeline to_fp16_vk_0 = nullptr;
++    vk_pipeline to_fp16_vk_1 = nullptr;
++    if (x_non_contig) {
++        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
++    }
++    if (y_non_contig) {
++        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
++    } else {
++        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
++    }
++
++    // Check for mmq first
++    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
++    vk_pipeline to_q8_1 = nullptr;
++
++    if (dmmv == nullptr) {
++        // Fall back to f16 dequant mul mat
++        dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
++        quantize_y = false;
++    }
++
++    if (quantize_y) {
++        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++    }
++
++    const bool qx_needs_dequant = x_non_contig;
++    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
++
++    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
++        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
++    }
++
++    // Not implemented
++    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
++    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
++    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
++    GGML_ASSERT(dmmv != nullptr);
++
++    const uint64_t x_ne = ggml_nelements(src0);
++    const uint64_t y_ne = ggml_nelements(src1);
++
++    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
++    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
++    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
++                                       (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
++
++    {
++        if (
++                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
++                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
++            GGML_ABORT("Requested preallocation size is too large");
++        }
++        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
++            ctx->prealloc_size_x = x_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
++            ctx->prealloc_size_y = y_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++
++        // Request descriptor sets
++        if (qx_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
++        }
++        if (qy_needs_dequant) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
++        }
++        if (quantize_y) {
++            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
++        }
++        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
++    }
++
++    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
++    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
++    vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
++    vk_subbuffer d_F0 = d_D;
++    vk_subbuffer d_X, d_Y;
++
++    if (qx_needs_dequant) {
++        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
++    } else {
++        d_X = d_Qx;
++    }
++    if (qy_needs_dequant || quantize_y) {
++        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
++    } else {
++        d_Y = d_Qy;
++    }
++
++    if (x_non_contig) {
++        if (ctx->prealloc_x_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++    }
++
++    if (x_non_contig) {
++        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
++        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
++    }
++    if (y_non_contig) {
++        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
++        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
++            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++    if (quantize_y) {
++        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
++            ctx->prealloc_y_last_tensor_used != src1) {
++            if (ctx->prealloc_y_need_sync) {
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
++            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
++            ctx->prealloc_y_last_tensor_used = src1;
++        }
++    }
++
++    uint32_t stride_batch_y = ne10*ne11;
++
++    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
++        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
++    }
++
++    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
++
++    uint32_t groups_x = ne01;
++    uint32_t groups_z = 1;
++
++    if (ne01 > max_groups_x) {
++        groups_z = 64;
++        groups_x = CEIL_DIV(groups_x, groups_z);
++    }
++
++    uint32_t fusion_flags = 0;
++
++    if (ctx->num_additional_fused_ops > 0) {
++        const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
++
++        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
++
++        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
++            fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
++        } else {
++            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
++            fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
++        }
++    }
++
++    vk_subbuffer d_F1 = d_D;
++    if (ctx->num_additional_fused_ops > 1) {
++        const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
++
++        d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
++        fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
++    }
++
++    // Loop over the batch dimension
++    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
++        const vk_mat_vec_id_push_constants pc = {
++            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
++            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
++            fusion_flags,
++            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
++        };
++        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
++            {
++                d_X,
++                d_Y,
++                d_D,
++                d_F0,
++                d_F1,
++                d_ids,
++            },
++            pc, { groups_x, (uint32_t)nei0, groups_z });
++    }
++
++    if (x_non_contig) {
++        ctx->prealloc_x_need_sync = true;
++    }
++    if (y_non_contig || quantize_y) {
++        ctx->prealloc_y_need_sync = true;
++    }
++}
++
++// Decide whether a MUL_MAT_ID node can use the matrix-vector (dmmv) path
++// instead of the full matrix-matrix path: only when the id tensor's second
++// dimension is small (<= 8) and the weight type is F32, F16, or quantized.
++static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    ggml_tensor * src0 = dst->src[0];   // expert weight matrices
++    ggml_tensor * src2 = dst->src[2];   // expert ids
++    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
++}
++
++// Dispatch a MUL_MAT_ID (mixture-of-experts matmul) node to either the
++// matrix-vector kernel (small id count, see ggml_vk_use_mul_mat_vec_id) or
++// the general matrix-matrix kernel.
++static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    ggml_tensor * src0 = dst->src[0];   // expert weights
++    ggml_tensor * src1 = dst->src[1];   // activations
++    ggml_tensor * src2 = dst->src[2];   // expert ids
++    VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
++    if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
++        // Mat-vec path takes the graph/node so it can fold fused follow-up ops.
++        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx);
++    } else {
++        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
++    }
++}
++
++// Estimate the shared-memory footprint of the scalar flash-attention shader
++// for head sizes (hsk, hsv) and the given tuning params, and return whether
++// it fits in the device's maxComputeSharedMemorySize limit.
++static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
++    GGML_UNUSED(f32acc);
++    // Needs to be kept up to date on shader changes
++    const uint32_t wg_size = params.workgroup_size;
++    const uint32_t Br = params.block_rows;
++    const uint32_t Bc = params.block_cols;
++
++    // Element size of the shader's shared arrays: fp16 when the device runs
++    // flash attention in fp16, otherwise fp32.
++    const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
++
++    // tmpsh is overestimated slightly
++    const uint32_t tmpsh = wg_size * sizeof(float);
++    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
++
++    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
++
++    // Q tile: Br rows of hsk elements, padded by one vec4 per row.
++    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
++
++    // K/V staging is sized for the larger head dimension; when shmem staging
++    // is disabled only a minimal vec4 slot is reserved.
++    const uint32_t D = std::max(hsk, hsv);
++    const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
++
++    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
++    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
++
++    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
++
++    return supported;
++}
++
++// Estimate the shared-memory footprint of the coopmat (cooperative-matrix)
++// flash-attention shader for head sizes (hsk, hsv) and return whether it fits
++// in the device's maxComputeSharedMemorySize limit. Mirrors the shader's
++// shared-array declarations; sizes are in bytes.
++static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
++    // Needs to be kept up to date on shader changes
++    const uint32_t Br = params.block_rows;
++    const uint32_t Bc = params.block_cols;
++
++    // Hardware cooperative-matrix tile dimensions assumed by the shader.
++    const uint32_t MatBr = 16, MatBc = 16;
++
++    const uint32_t row_split = Bc / MatBc;
++
++    // Head sizes padded up to a multiple of the 16-wide matrix tile.
++    const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
++    const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
++
++    // Accumulator element size (fp32 vs fp16) and the size of an f16vec4.
++    const uint32_t acctype = f32acc ? 4 : 2;
++    const uint32_t f16vec4 = 8;
++
++    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
++
++    const uint32_t qstride = hsk_pad / 4 + 2;
++    const uint32_t Qf = Br * qstride * f16vec4;
++
++    const uint32_t psh_stride = Br / 4 + 2;
++    const uint32_t Psh = Bc * psh_stride * f16vec4;
++
++    // Scores tile; extra row padding only for small head sizes.
++    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
++    const uint32_t sfsh = Bc * sfshstride * acctype;
++
++    // K/V staging: the buffer is shared, sized for the larger of the K-load
++    // stride and the V-load stride.
++    const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
++    const uint32_t vsh_stride = MatBc / 4 * row_split;
++    const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
++
++    const uint32_t osh_stride = params.row_split * MatBr / 4;
++    const uint32_t pvsh = MatBc * osh_stride * f16vec4;
++
++    // Per-row ALiBi slope storage.
++    const uint32_t slope = Br * acctype;
++
++    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
++    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
++
++    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
++
++    return supported;
++}
++
++// Record a flash-attention dispatch computing dst = attention(q, k, v) with
++// an optional mask and optional attention sinks. Selects tuning parameters
++// (block sizes Br/Bc, scalar vs coopmat path, fp16 vs fp32 accumulation),
++// optionally switches to grouped-query (gqa) addressing for small N, splits
++// the KV dimension across workgroups (split_k) when there are too few
++// workgroups to fill the GPU, and runs auxiliary passes: a mask-optimization
++// prepass and a split-k reduction pass.
++static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {
++    // This whole multi-line chain is a single VK_LOG_DEBUG(...) invocation,
++    // so the std::cerr statements only exist when debug logging is enabled.
++    VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
++    std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
++    std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    if (sinks) {
++        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
++    }
++    std::cerr << "))");
++
++    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
++    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
++    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
++    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
++    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
++    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
++    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
++    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
++
++    // Mask dimensions, zero when no mask is present.
++    const uint32_t nem0 = mask ? mask->ne[0] : 0;
++    const uint32_t nem1 = mask ? mask->ne[1] : 0;
++    const uint32_t nem2 = mask ? mask->ne[2] : 0;
++    const uint32_t nem3 = mask ? mask->ne[3] : 0;
++
++    // Head sizes for K and V, number of queries N, and KV sequence length.
++    const uint32_t HSK = nek0;
++    const uint32_t HSV = nev0;
++    uint32_t N = neq1;
++    const uint32_t KV = nek1;
++
++    GGML_ASSERT(ne0 == HSV);
++    GGML_ASSERT(ne2 == N);
++
++    // input tensor rows must be contiguous
++    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
++    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
++    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
++
++    GGML_ASSERT(neq0 == HSK);
++
++    GGML_ASSERT(neq1 == N);
++
++    GGML_ASSERT(nev1 == nek1);
++
++    // dst cannot be transposed or permuted
++    GGML_ASSERT(nb0 == sizeof(float));
++    GGML_ASSERT(nb0 <= nb1);
++    GGML_ASSERT(nb1 <= nb2);
++    GGML_ASSERT(nb2 <= nb3);
++
++    assert(dst->type == GGML_TYPE_F32);
++    assert(q->type == GGML_TYPE_F32);
++    assert(k->type == v->type);
++
++    // qk_ratio = Q heads per KV head; > 1 indicates grouped-query attention.
++    uint32_t gqa_ratio = 1;
++    uint32_t qk_ratio = neq2 / nek2;
++    uint32_t workgroups_x = (uint32_t)neq1;
++    uint32_t workgroups_y = (uint32_t)neq2;
++    uint32_t workgroups_z = (uint32_t)neq3;
++
++    // fp32 accumulation when the device doesn't do fp16 FA, or the op
++    // explicitly requests F32 precision in op_params[3].
++    const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32;
++
++    // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
++    // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
++    vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
++    const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
++
++    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
++        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
++        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
++        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
++        // and change addressing calculations to index Q's dimension 2.
++        gqa_ratio = qk_ratio;
++        N = gqa_ratio;
++        workgroups_y /= gqa_ratio;
++    }
++
++    // Re-tune with the final (possibly gqa-adjusted) N.
++    tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
++
++    // Row strides in elements (nb is in bytes).
++    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
++    uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
++    uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
++
++    // For F32, the shader treats it as a block of size 4 (for vec4 loads)
++    if (k->type == GGML_TYPE_F32) {
++        k_stride /= 4;
++    }
++    if (v->type == GGML_TYPE_F32) {
++        v_stride /= 4;
++    }
++
++    const uint32_t alignment = tuning_params.block_cols;
++    bool aligned = (KV % alignment) == 0 &&
++                   // the "aligned" shader variant will forcibly align strides, for performance
++                   (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
++
++    // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
++    if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
++        aligned = false;
++    }
++
++    // Softmax scale, ALiBi max bias, and logit softcap from op params.
++    float scale         = 1.0f;
++    float max_bias      = 0.0f;
++    float logit_softcap = 0.0f;
++
++    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
++    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
++    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
++
++    if (logit_softcap != 0) {
++        // The shader re-multiplies by the softcap; pre-divide the scale here.
++        scale /= logit_softcap;
++    }
++
++    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
++    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
++    vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc,
++                                                                   mask != nullptr, use_mask_opt, logit_softcap != 0);
++
++    vk_pipeline pipeline = nullptr;
++
++    // Look up (or lazily insert) the pipeline in the per-device cache, keyed
++    // by k->type and the full pipeline state. Guarded by the device mutex.
++    {
++        std::lock_guard guard(ctx->device->mutex);
++        auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
++        auto it = pipelines.find(fa_pipeline_state);
++        if (it != pipelines.end()) {
++            pipeline = it->second;
++        } else {
++            // NOTE(review): make_shared() here has no template argument — the
++            // pipeline type was probably lost in text extraction; confirm
++            // against the upstream source before relying on this line.
++            pipelines[fa_pipeline_state] = pipeline = std::make_shared();
++        }
++    }
++
++    assert(pipeline);
++    // Compile early to initialize wg_denoms.
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    uint32_t split_kv = KV;
++    uint32_t split_k = 1;
++
++    // Intel Alchemist prefers more workgroups
++    const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
++
++    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
++    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
++
++    const uint32_t Br = fa_pipeline_state.Br;
++    const uint32_t Bc = fa_pipeline_state.Bc;
++
++    GGML_ASSERT(Br == pipeline->wg_denoms[0]);
++    // Tr = number of row tiles needed to cover N queries.
++    const uint32_t Tr = CEIL_DIV(N, Br);
++
++    // Try to use split_k when KV is large enough to be worth the overhead.
++    if (gqa_ratio > 1 && workgroups_x <= Br) {
++        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
++    } else if (gqa_ratio <= 1) {
++        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
++        if (total_wgs_no_split < shader_core_count * 2) {
++            split_k = shader_core_count * 2 / total_wgs_no_split;
++        }
++    }
++
++    if (split_k > 1) {
++        // Try to evenly split KV into split_k chunks, but it needs to be a multiple
++        // of "align", so recompute split_k based on that.
++        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
++        split_k = CEIL_DIV(KV, split_kv);
++    }
++
++    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
++    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
++    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
++    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
++    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
++    if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
++        GGML_ABORT("Requested preallocation size is too large");
++    }
++    if (ctx->prealloc_size_split_k < split_k_size) {
++        ctx->prealloc_size_split_k = split_k_size;
++        ggml_vk_preallocate_buffers(ctx, subctx);
++    }
++
++    // Mask-opt prepass output: one bitfield dword per 16*Bc mask columns,
++    // per Br-row tile, per mask batch.
++    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
++    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
++
++    vk_pipeline pipeline_fa_mask_opt = nullptr;
++    if (use_mask_opt) {
++        std::lock_guard guard(ctx->device->mutex);
++        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
++        auto it = pipelines.find({Br, Bc});
++        if (it != pipelines.end()) {
++            pipeline_fa_mask_opt = it->second;
++        } else {
++            // NOTE(review): same missing make_shared template argument as the
++            // FA pipeline cache above — likely extraction damage; confirm.
++            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared();
++        }
++        assert(pipeline_fa_mask_opt);
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
++
++        // The mask-opt output lives in the shared prealloc_y scratch buffer.
++        if (ctx->prealloc_size_y < mask_opt_size) {
++            ctx->prealloc_size_y = mask_opt_size;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (ctx->prealloc_y_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++    }
++
++    // ALiBi slope bases m0/m1, derived from the number of heads.
++    const uint32_t n_head_kv   = neq2;
++    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
++    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
++    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
++
++    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
++    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
++    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++    // Optional bindings fall back to q_buf as a harmless placeholder so every
++    // descriptor slot is bound.
++    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
++    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
++    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
++
++    // Pack the "has sinks" flag into the top byte alongside n_head_log2.
++    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
++
++    if (use_mask_opt)
++    {
++        // Prepass: scan the mask and build per-tile skip bitfields the main
++        // shader uses to skip fully-masked KV tiles.
++        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
++            nem0,
++            nem1,
++            nem2,
++            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
++            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
++            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
++            mask_opt_num_dwords,
++            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
++            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
++        };
++
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
++                                  { mask_buf, mask_opt_buf }, opt_pc,
++                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++
++    const vk_flash_attn_push_constants pc = { N, KV,
++                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
++                                              (uint32_t)neq2, (uint32_t)neq3,
++                                              (uint32_t)nek2, (uint32_t)nek3,
++                                              (uint32_t)nev2, (uint32_t)nev3,
++                                              nem1, nem2, nem3,
++                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
++                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
++                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
++                                              scale, max_bias, logit_softcap,
++                                              mask_n_head_log2, m0, m1,
++                                              gqa_ratio, split_kv, split_k };
++
++    if (split_k > 1) {
++        // FA writes partial O/L/M results to the split_k scratch buffer, then
++        // a reduction pass combines them into dst.
++        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
++
++        if (ctx->prealloc_split_k_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++
++        // We reuse workgroups_x to mean the number of splits, so we need to
++        // cancel out the divide by wg_denoms[0].
++        uint32_t dispatch_x;
++        if (gqa_ratio > 1) {
++            workgroups_x *= pipeline->wg_denoms[0];
++            dispatch_x = split_k * workgroups_x;
++        } else {
++            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
++        }
++
++        vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
++                                    pc, { dispatch_x, workgroups_y, workgroups_z });
++
++        ggml_vk_sync_buffers(ctx, subctx);
++        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
++        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
++                                    {split_k_buf, sinks_buf, dst_buf},
++                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
++        ctx->prealloc_split_k_need_sync = true;
++    } else {
++        if (gqa_ratio > 1) {
++            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
++            workgroups_x *= pipeline->wg_denoms[0];
++        }
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
++                                    pc, { workgroups_x, workgroups_y, workgroups_z });
++    }
++}
++
++// Choose a convolution tile shape based on the output-channel count K and the
++// flattened output size NPQ. Prefers larger tiles only when they still
++// produce enough tiles to keep the GPU's shader cores busy (>= 2x core
++// count); otherwise falls back to the smallest 64x32 shape.
++static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) {
++    // Number of tiles a given shape would produce for this problem size.
++    auto n_tiles = [&](vk_conv_shapes s) {
++        return CEIL_DIV(K, vk_conv_block_sizes[s].K)
++            * CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ);
++    };
++
++    // We can't query number of shader cores on Intel, use 32 as a placeholder
++    // so small convolutions will still choose a smaller tile.
++    const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
++
++    if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
++        return CONV_SHAPE_128x128;
++    } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
++        return CONV_SHAPE_32x256;
++    } else {
++        return CONV_SHAPE_64x32;
++    }
++}
++
++static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
++    switch (op) {
++    case GGML_OP_GET_ROWS:
++        GGML_ASSERT(src1->type == GGML_TYPE_I32);
++        if (src0->type == GGML_TYPE_I32) {
++            // i32 src only supports i32 result
++            GGML_ASSERT(dst->type == GGML_TYPE_I32);
++            return ctx->device->pipeline_get_rows[src0->type];
++        }
++        if (dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_get_rows[src0->type];
++        }
++        if (dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_get_rows_f32[src0->type];
++        }
++        return nullptr;
++    case GGML_OP_ACC:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_acc_f32;
++        }
++        return nullptr;
++    case GGML_OP_SET:
++        if (src0->type == src1->type && src0->type == dst->type &&
++            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
++            return ctx->device->pipeline_set_f32;
++        }
++        return nullptr;
++    case GGML_OP_ADD:
++    case GGML_OP_SUB:
++    case GGML_OP_MUL:
++    case GGML_OP_DIV:
++        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
++            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
++            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
++            return nullptr;
++        }
++        switch (op) {
++        case GGML_OP_ADD:
++        {
++            if (ctx->num_additional_fused_ops > 0) {
++                if (ctx->do_add_rms_partials) {
++                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
++                } else {
++                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
++                }
++            }
++            if (ctx->do_add_rms_partials) {
++                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
++                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
++            } else {
++                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
++                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
++            }
++        }
++        case GGML_OP_SUB:
++        {
++            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
++            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
++        }
++        case GGML_OP_MUL:
++        {
++            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
++            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
++        }
++        case GGML_OP_DIV:
++        {
++            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
++            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
++        }
++        default:
++            break;
++        }
++        return nullptr;
++    case GGML_OP_ADD_ID:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_add_id_f32;
++        }
++        return nullptr;
++    case GGML_OP_CONCAT:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_concat_f32;
++        }
++        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_concat_f16;
++        }
++        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
++            return ctx->device->pipeline_concat_i32;
++        }
++        return nullptr;
++    case GGML_OP_UPSCALE:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
++            switch (mode) {
++                case GGML_SCALE_MODE_NEAREST:
++                    return ctx->device->pipeline_upscale_nearest_f32;
++                case GGML_SCALE_MODE_BILINEAR:
++                    return ctx->device->pipeline_upscale_bilinear_f32;
++                case GGML_SCALE_MODE_BICUBIC:
++                    return ctx->device->pipeline_upscale_bicubic_f32;
++                case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
++                    return ctx->device->pipeline_upscale_bilinear_antialias_f32;
++                default:
++                    return nullptr;
++            }
++        }
++        return nullptr;
++    case GGML_OP_SCALE:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_scale_f32;
++        }
++        return nullptr;
++    case GGML_OP_SQR:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_sqr_f32;
++        }
++        return nullptr;
++    case GGML_OP_SQRT:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_sqrt_f32;
++        }
++        return nullptr;
++    case GGML_OP_SIN:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_sin_f32;
++        }
++        return nullptr;
++    case GGML_OP_COS:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_cos_f32;
++        }
++        return nullptr;
++    case GGML_OP_LOG:
++        if (src0->type == dst->type &&
++            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
++            return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
++        }
++        return nullptr;
++    case GGML_OP_TRI:
++        if (src0->type == dst->type &&
++            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
++            return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
++        }
++        return nullptr;
++    case GGML_OP_DIAG:
++        if (src0->type == dst->type &&
++            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
++            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
++        }
++        return nullptr;
++    case GGML_OP_CLAMP:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_clamp_f32;
++        }
++        return nullptr;
++    case GGML_OP_PAD:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_pad_f32;
++        }
++        return nullptr;
++    case GGML_OP_ROLL:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_roll_f32;
++        }
++        return nullptr;
++    case GGML_OP_REPEAT:
++        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
++            return ctx->device->pipeline_repeat_f32;
++        }
++        return nullptr;
++    case GGML_OP_REPEAT_BACK:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_repeat_back_f32;
++        }
++        return nullptr;
++    case GGML_OP_CPY:
++    case GGML_OP_CONT:
++    case GGML_OP_DUP:
++        return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
++    case GGML_OP_SET_ROWS:
++        if (src1->type == GGML_TYPE_I64) {
++            return ctx->device->pipeline_set_rows_i64[dst->type];
++        } else {
++            return ctx->device->pipeline_set_rows_i32[dst->type];
++        }
++    case GGML_OP_SILU_BACK:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_silu_back_f32;
++        }
++        return nullptr;
++    case GGML_OP_NORM:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_norm_f32;
++        }
++        return nullptr;
++    case GGML_OP_GROUP_NORM:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_group_norm_f32;
++        }
++        return nullptr;
++    case GGML_OP_RMS_NORM:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            if (ctx->do_add_rms_partials) {
++                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
++            } else {
++                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
++            }
++        }
++        return nullptr;
++    case GGML_OP_RMS_NORM_BACK:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_rms_norm_back_f32;
++        }
++        return nullptr;
++    case GGML_OP_L2_NORM:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_l2_norm_f32;
++        }
++        return nullptr;
++    case GGML_OP_UNARY:
++        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
++            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
++            (src0->type != dst->type)) {
++            return nullptr;
++        }
++
++        switch (ggml_get_unary_op(dst)) {
++            case GGML_UNARY_OP_EXP:
++                return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_SILU:
++                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_GELU:
++                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_GELU_ERF:
++                return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_GELU_QUICK:
++                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_RELU:
++                return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_XIELU:
++                return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_NEG:
++                return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_TANH:
++                return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_SIGMOID:
++                return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_HARDSIGMOID:
++                return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_HARDSWISH:
++                return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_ABS:
++                return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_SOFTPLUS:
++                return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_STEP:
++                return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_ROUND:
++                return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_CEIL:
++                return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_FLOOR:
++                return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
++            case GGML_UNARY_OP_TRUNC:
++                return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
++            default:
++                break;
++        }
++        return nullptr;
++    case GGML_OP_GLU:
++        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
++            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
++            (src0->type != dst->type)) {
++            return nullptr;
++        }
++
++        switch (ggml_get_glu_op(dst)) {
++            case GGML_GLU_OP_GEGLU:
++                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
++            case GGML_GLU_OP_REGLU:
++                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
++            case GGML_GLU_OP_SWIGLU:
++                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
++            case GGML_GLU_OP_SWIGLU_OAI:
++                return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
++            case GGML_GLU_OP_GEGLU_ERF:
++                return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
++            case GGML_GLU_OP_GEGLU_QUICK:
++                return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
++            default:
++                break;
++        }
++        return nullptr;
++    case GGML_OP_DIAG_MASK_INF:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_diag_mask_inf_f32;
++        }
++        return nullptr;
++    case GGML_OP_SOFT_MAX:
++        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
++        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
++
++        if (ctx->num_additional_fused_ops) {
++            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
++            GGML_ASSERT(idx < num_topk_moe_pipelines);
++            // use n_experts from push constant if it's not equal to the power of two spec constant
++            bool use_push = dst->ne[0] != (1u << idx);
++            return ctx->device->pipeline_topk_moe[idx][use_push];
++        }
++
++        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
++            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
++        }
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
++            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
++        }
++        return nullptr;
++    case GGML_OP_SOFT_MAX_BACK:
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_soft_max_back_f32;
++        }
++        return nullptr;
++    case GGML_OP_ROPE:
++    case GGML_OP_ROPE_BACK:
++        {
++            const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
++            const int mode = ((const int32_t *) rope->op_params)[2];
++            const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
++            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
++            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
++
++            if (is_neox) {
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++                    return ctx->device->pipeline_rope_neox_f32;
++                }
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_neox_f32_f16;
++                }
++                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_neox_f16;
++                }
++            } else if (is_mrope && !is_vision) {
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++                    return ctx->device->pipeline_rope_multi_f32;
++                }
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_multi_f32_f16;
++                }
++                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_multi_f16;
++                }
++            } else if (is_vision) {
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++                    return ctx->device->pipeline_rope_vision_f32;
++                }
++                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_vision_f16;
++                }
++            } else {
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++                    return ctx->device->pipeline_rope_norm_f32;
++                }
++                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_norm_f32_f16;
++                }
++                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++                    return ctx->device->pipeline_rope_norm_f16;
++                }
++            }
++            return nullptr;
++        }
++    case GGML_OP_SUM:
++    case GGML_OP_SUM_ROWS:
++    case GGML_OP_MEAN:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_sum_rows_f32;
++        }
++        return nullptr;
++    case GGML_OP_CUMSUM:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            if (src0->ne[0] <= 512) {
++                return ctx->device->pipeline_cumsum_small_f32;
++            } else {
++                return ctx->device->pipeline_cumsum_f32;
++            }
++        }
++        return nullptr;
++    case GGML_OP_SOLVE_TRI:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++
++            vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
++
++            vk_pipeline pipeline = nullptr;
++
++            {
++                std::lock_guard guard(ctx->device->mutex);
++                auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
++                if (it != ctx->device->pipeline_solve_tri_f32.end()) {
++                    pipeline = it->second;
++                } else {
++                    ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
++                }
++            }
++
++            return pipeline;
++        }
++        return nullptr;
++    case GGML_OP_ARGMAX:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
++            return ctx->device->pipeline_argmax_f32;
++        }
++        return nullptr;
++    case GGML_OP_COUNT_EQUAL:
++        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
++            return ctx->device->pipeline_count_equal_i32;
++        }
++        return nullptr;
++    case GGML_OP_IM2COL:
++        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_im2col_f32;
++        }
++        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_im2col_f32_f16;
++        }
++        return nullptr;
++    case GGML_OP_IM2COL_3D:
++        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_im2col_3d_f32;
++        }
++        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_im2col_3d_f32_f16;
++        }
++        return nullptr;
++    case GGML_OP_TIMESTEP_EMBEDDING:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_timestep_embedding_f32;
++        }
++        return nullptr;
++    case GGML_OP_CONV_TRANSPOSE_1D:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_conv_transpose_1d_f32;
++        }
++        return nullptr;
++    case GGML_OP_POOL_2D:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_pool2d_f32;
++        }
++        return nullptr;
++    case GGML_OP_RWKV_WKV6:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_rwkv_wkv6_f32;
++        }
++        return nullptr;
++    case GGML_OP_RWKV_WKV7:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_rwkv_wkv7_f32;
++        }
++        return nullptr;
++    case GGML_OP_SSM_SCAN:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            const uint32_t d_state = src0->ne[0];
++            if (d_state == 128) {
++                return ctx->device->pipeline_ssm_scan_f32_d128;
++            } else if (d_state == 256) {
++                return ctx->device->pipeline_ssm_scan_f32_d256;
++            }
++        }
++        return nullptr;
++    case GGML_OP_SSM_CONV:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_ssm_conv_f32;
++        }
++        return nullptr;
++    case GGML_OP_OPT_STEP_ADAMW:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_opt_step_adamw_f32;
++        }
++        return nullptr;
++    case GGML_OP_OPT_STEP_SGD:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_opt_step_sgd_f32;
++        }
++        return nullptr;
++    case GGML_OP_LEAKY_RELU:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_leaky_relu_f32;
++        }
++        return nullptr;
++    case GGML_OP_CONV_2D:
++    case GGML_OP_CONV_TRANSPOSE_2D:
++        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            uint32_t K = dst->ne[2]; // Cout
++            uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW
++            vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ);
++
++            bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
++            uint32_t KW = (uint32_t)src0->ne[0];
++            uint32_t KH = (uint32_t)src0->ne[1];
++            uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0));
++            uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0;
++            uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0;
++            uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
++            uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
++            uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
++            vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
++
++            std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
++            if (op == GGML_OP_CONV_2D) {
++                if (src0->type == GGML_TYPE_F32) {
++                    pipelines = &ctx->device->pipeline_conv2d_f32[shape];
++                } else if (src0->type == GGML_TYPE_F16) {
++                    pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
++                }
++            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
++                if (src0->type == GGML_TYPE_F32) {
++                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
++                } else if (src0->type == GGML_TYPE_F16) {
++                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
++                }
++            }
++
++            vk_pipeline pipeline = nullptr;
++
++            {
++                std::lock_guard guard(ctx->device->mutex);
++                auto it = pipelines->find(conv2d_pipeline_state);
++                if (it != pipelines->end()) {
++                    pipeline = it->second;
++                } else {
++                    (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
++                }
++            }
++
++            return pipeline;
++        }
++        return nullptr;
++    case GGML_OP_CONV_2D_DW:
++        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            if (ggml_is_contiguous(src1)) {
++                return ctx->device->pipeline_conv2d_dw_whcn_f32;
++            } else if (ggml_is_contiguous_channels(src1)) {
++                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
++            }
++        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
++            if (ggml_is_contiguous(src1)) {
++                return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
++            } else if (ggml_is_contiguous_channels(src1)) {
++                return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
++            }
++        }
++        return nullptr;
++    case GGML_OP_ADD1:
++        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_add1_f16_f16;
++        }
++        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
++            return ctx->device->pipeline_add1_f16_f32;
++        }
++        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_add1_f32_f32;
++        }
++        return nullptr;
++    case GGML_OP_ARANGE:
++        if (dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_arange_f32;
++        }
++        return nullptr;
++    case GGML_OP_FILL:
++        if (dst->type == GGML_TYPE_F32) {
++            return ctx->device->pipeline_fill_f32;
++        }
++        return nullptr;
++    default:
++        return nullptr;
++    }
++
++    GGML_UNUSED(src2);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.misalign_offsets = (a_offset << 16) | d_offset;
++
++    GGML_UNUSED(src1);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.misalign_offsets = (a_offset << 16) | d_offset;
++
++    GGML_UNUSED(src1);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.misalign_offsets = (a_offset << 16) | d_offset;
++
++    GGML_UNUSED(src1);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.misalign_offsets = (a_offset << 16) | d_offset;
++
++    GGML_UNUSED(src0);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
++    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
++
++    p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
++
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
++    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
++    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
++
++    p.a_offset = a_offset;
++    p.d_offset = d_offset;
++
++    GGML_UNUSED(src1);
++    GGML_UNUSED(src2);
++    GGML_UNUSED(src3);
++}
++
++template <typename PC>
++static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
++    VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
++    if (src1 != nullptr) {
++        std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
++    }
++    if (src2 != nullptr) {
++        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
++    }
++    if (src3 != nullptr) {
++        std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
++    }
++    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
++    std::cerr << "), " << ggml_op_name(op) << ")");
++    GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
++    GGML_ASSERT(dst->buffer != nullptr);
++    const uint64_t ne00 = src0->ne[0];
++    const uint64_t ne01 = src0->ne[1];
++    const uint64_t ne02 = src0->ne[2];
++    const uint64_t ne03 = src0->ne[3];
++
++    const bool use_src1 = src1 != nullptr;
++    const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
++    const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
++    const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
++    const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
++
++    const bool use_src2 = src2 != nullptr;
++    const bool use_src3 = src3 != nullptr;
++
++    init_pushconst_fastdiv(pc);
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
++
++    if (pipeline == nullptr) {
++        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
++        if (src1 != nullptr) {
++            std::cerr << " and " << ggml_type_name(src1->type);
++        }
++        std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
++        GGML_ABORT("fatal error");
++    }
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
++    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
++    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
++    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
++
++    // Compute misalignment offset for descriptors and store it in in push constants.
++    init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
++
++    std::array<uint32_t, 3> elements;
++
++    switch (op) {
++    case GGML_OP_NORM:
++    case GGML_OP_RMS_NORM_BACK:
++    case GGML_OP_L2_NORM:
++    case GGML_OP_SOFT_MAX:
++    case GGML_OP_SOFT_MAX_BACK:
++    case GGML_OP_SUM_ROWS:
++    case GGML_OP_CUMSUM:
++    case GGML_OP_MEAN:
++    case GGML_OP_ARGMAX:
++        {
++            const uint32_t nr = ggml_nrows(src0);
++            if (nr > 262144) {
++                elements = { 512, 512, CEIL_DIV(nr, 262144) };
++            } else if (nr > 512) {
++                elements = { 512, CEIL_DIV(nr, 512), 1 };
++            } else {
++                elements = { nr, 1, 1 };
++            }
++        } break;
++    case GGML_OP_SOLVE_TRI:
++        {
++            uint32_t nr = (uint32_t)(ne02 * ne03);
++            if (nr > 262144) {
++                elements = { 512, 512, CEIL_DIV(nr, 262144) };
++            } else if (nr > 512) {
++                elements = { 512, CEIL_DIV(nr, 512), 1 };
++            } else {
++                elements = { nr, 1, 1 };
++            }
++        }
++        break;
++    case GGML_OP_RMS_NORM:
++        if (ctx->do_add_rms_partials) {
++            // Run one element per thread, 128 threads per workgroup
++            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
++        } else {
++            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
++        }
++        break;
++
++    case GGML_OP_SUM:
++        // We use GGML_OP_SUM_ROWS with 1 row.
++        elements = { 1, 1, 1 };
++        break;
++    case GGML_OP_GROUP_NORM:
++        {
++            const uint32_t num_groups = dst->op_params[0];
++            elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
++        } break;
++    case GGML_OP_DIAG_MASK_INF:
++        elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
++        break;
++    case GGML_OP_ROPE:
++    case GGML_OP_ROPE_BACK:
++        {
++            uint32_t nrows = (uint32_t)ggml_nrows(src0);
++            uint32_t z = 1;
++            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
++                z = CEIL_DIV(nrows, 32768);
++                nrows = 32768;
++            }
++            elements = { nrows, (uint32_t)ne00, z };
++
++        } break;
++    case GGML_OP_GET_ROWS:
++        elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
++        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++        elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++        break;
++    case GGML_OP_ARGSORT:
++        GGML_ASSERT(0);
++        break;
++    case GGML_OP_IM2COL:
++        {
++            const bool is_2D = dst->op_params[6] == 1;
++
++            const uint32_t IC = src1->ne[is_2D ? 2 : 1];
++
++            const uint32_t KH = is_2D ? src0->ne[1] : 1;
++            const uint32_t KW =         src0->ne[0];
++
++            const uint32_t OH = is_2D ? dst->ne[2] : 1;
++            const uint32_t OW =         dst->ne[1];
++
++            const uint32_t batch = src1->ne[is_2D ? 3 : 2];
++
++            elements = { OW * KW * KH, OH, batch * IC };
++            elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++        } break;
++    case GGML_OP_IM2COL_3D:
++        {
++            const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
++
++            const uint32_t N  = ne13 / IC;
++
++            const uint32_t KD = ne02;
++            const uint32_t KH = ne01;
++            const uint32_t KW = ne00;
++
++            const uint32_t OD = dst->ne[3] / N;
++            const uint32_t OH = dst->ne[2];
++            const uint32_t OW = dst->ne[1];
++
++            const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
++            const uint32_t N_OD_OH = N*OD*OH;
++
++            elements = { IC_KD_KH_KW, OW, N_OD_OH };
++            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++        } break;
++    case GGML_OP_TIMESTEP_EMBEDDING:
++        {
++            const uint32_t dim = dst->op_params[0];
++            uint32_t half_ceil = (dim + 1) / 2;
++            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
++        } break;
++    case GGML_OP_CONV_TRANSPOSE_1D:
++        {
++            elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
++        } break;
++    case GGML_OP_POOL_2D:
++        {
++            const uint32_t N = dst->ne[3];
++            const uint32_t OC = dst->ne[2];
++            const uint32_t OH = dst->ne[1];
++            const uint32_t OW = dst->ne[0];
++            elements = { N * OC * OH * OW, 1, 1};
++        } break;
++    case GGML_OP_CONV_2D:
++    case GGML_OP_CONV_TRANSPOSE_2D:
++        if constexpr (std::is_same_v<PC, vk_op_conv2d_push_constants>) {
++            const uint32_t NPQ = pc.N * pc.OH * pc.OW;
++            const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ);
++            const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
++
++            elements = { pc.Cout, NPQ_blocks, 1 };
++            if (elements[1] > 512) {
++                elements[2] = CEIL_DIV(elements[1], 512);
++                elements[1] = 512;
++            }
++        } else {
++            GGML_ABORT("invalid push constant type for CONV_2D");
++        }
++        break;
++    case GGML_OP_ADD:
++    case GGML_OP_SUB:
++    case GGML_OP_DIV:
++    case GGML_OP_MUL:
++    case GGML_OP_ADD1:
++    case GGML_OP_ARANGE:
++    case GGML_OP_FILL:
++    case GGML_OP_SCALE:
++    case GGML_OP_SQR:
++    case GGML_OP_SQRT:
++    case GGML_OP_SIN:
++    case GGML_OP_COS:
++    case GGML_OP_LOG:
++    case GGML_OP_TRI:
++    case GGML_OP_DIAG:
++    case GGML_OP_CLAMP:
++    case GGML_OP_PAD:
++    case GGML_OP_ROLL:
++    case GGML_OP_REPEAT:
++    case GGML_OP_REPEAT_BACK:
++    case GGML_OP_CPY:
++    case GGML_OP_CONCAT:
++    case GGML_OP_UPSCALE:
++    case GGML_OP_UNARY:
++    case GGML_OP_GLU:
++    case GGML_OP_CONV_2D_DW:
++        {
++            uint32_t ne = ggml_nelements(dst);
++            if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
++                // Convert from number of logical elements to 2- or 4-byte units.
++                ne /= ggml_blck_size(src0->type);
++                if ((ggml_type_size(src0->type) % 4) == 0) {
++                    ne *= ggml_type_size(src0->type) / 4;
++                } else {
++                    ne *= ggml_type_size(src0->type) / 2;
++                }
++            }
++            // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
++            // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
++            // So divide by block size here before splitting into 512x512 groups.
++            if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
++                ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
++            }
++            if (ne > 262144) {
++                elements = { 512, 512, CEIL_DIV(ne, 262144) };
++            } else if (ne > 512) {
++                elements = { 512, CEIL_DIV(ne, 512), 1 };
++            } else {
++                elements = { ne, 1, 1 };
++            }
++
++            if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
++                pipeline == ctx->device->pipeline_cpy_transpose_16) {
++                // 32x32 tiles
++                elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
++                elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
++                elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
++                elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
++                elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++                elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
++            }
++        } break;
++    case GGML_OP_ADD_ID:
++        {
++            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
++        } break;
++    case GGML_OP_SET_ROWS:
++        {
++            uint32_t ne = ggml_nelements(src0);
++            if (ggml_is_quantized(dst->type)) {
++                // quants run 32 threads each doing QUANT_K elements
++                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
++            } else {
++                // scalar types do one element per thread, running 512 threads
++                ne = CEIL_DIV(ne, 512);
++            }
++            if (ne > 262144) {
++                elements = { 512, 512, CEIL_DIV(ne, 262144) };
++            } else if (ne > 512) {
++                elements = { 512, CEIL_DIV(ne, 512), 1 };
++            } else {
++                elements = { ne, 1, 1 };
++            }
++        }
++        break;
++    case GGML_OP_SSM_CONV:
++        {
++            const uint32_t nr  = src0->ne[1];
++            const uint32_t n_t = dst->ne[1];
++            const uint32_t n_s = dst->ne[2];
++            elements = { nr, n_t, n_s };
++        }
++        break;
++    default:
++        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
++        break;
++    }
++
++    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
++        vk_subbuffer a_buf = src0_buf;
++        if (ctx->do_add_rms_partials) {
++            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
++        }
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
++    } else if (op == GGML_OP_GLU) {
++        // Empty src1 is possible in glu, but the shader needs a buffer
++        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
++    } else if (op == GGML_OP_SOFT_MAX) {
++        // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
++        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
++        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
++    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
++        // Empty src2 and src3 is possible in rope, but the shader needs a buffer
++        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
++        vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
++    } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
++        if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
++            // buffer device address path doesn't use dst buffer
++            dst_buf.size = 1;
++        }
++        // im2col uses only src1 and dst buffers
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
++    } else if (op == GGML_OP_COUNT_EQUAL) {
++        // count_equal assumes that destination buffer is initialized with zeroes
++        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
++        ggml_vk_sync_buffers(ctx, subctx);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
++    } else if (op == GGML_OP_OPT_STEP_SGD) {
++        // OPT_STEP_SGD works on src0, it does not need dst
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
++    } else if (use_src3) {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
++    } else if (use_src2) {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
++    } else if (use_src1) {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
++    } else {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
++    }
++}
++
++static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
++    int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
++    int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
++    int offset = dst->op_params[3] / src0_type_size; // offset in bytes
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
++        0,
++        0.0f, 0.0f, offset,
++    });
++}
++
++static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
++    const ggml_tensor *first_node = cgraph->nodes[node_idx];
++    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
++
++    // Make a list of all the tensors used by the op.
++    // Last element of the list is the dest tensor.
++    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
++    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
++    uint32_t num_tensors = num_srcs + 1;
++    GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
++
++    tensors[0] = first_node->src[0];
++    tensors[1] = first_node->src[1];
++    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
++        // check whether the previous result is src[0] or src[1]
++        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
++            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
++        } else {
++            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
++        }
++    }
++    tensors[num_srcs] = dst;
++
++    vk_op_multi_add_push_constants pc;
++    pc.ne20 = (uint32_t)dst->ne[0];
++    pc.ne21 = (uint32_t)dst->ne[1];
++    pc.ne22 = (uint32_t)dst->ne[2];
++    pc.ne23 = (uint32_t)dst->ne[3];
++
++    for (uint32_t i = 0; i < num_tensors; ++i) {
++        const ggml_tensor *t = tensors[i];
++        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
++        pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
++        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
++        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
++    }
++    pc.rms_partials = ctx->do_add_rms_partials;
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
++
++    if (pipeline == nullptr) {
++        std::cerr << "ggml_vulkan: Error: Missing multi_add";
++        GGML_ABORT("fatal error");
++    }
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
++    vk_buffer buf[MAX_PARAMETER_COUNT];
++    size_t offset[MAX_PARAMETER_COUNT];
++    bool uma[MAX_PARAMETER_COUNT];
++
++    for (uint32_t i = 0; i < num_tensors; ++i) {
++        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
++        buf[i] = nullptr;
++        offset[i] = 0;
++        uma[i] = false;
++
++        if (ctx->device->uma) {
++            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
++            uma[i] = buf[i] != nullptr;
++        }
++        if (!uma[i]) {
++            buf[i] = buf_ctx[i]->dev_buffer;
++            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
++        }
++        GGML_ASSERT(buf[i] != nullptr);
++    }
++    // If any remaining descriptors are unused, just point them at src[0]
++    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
++        buf[i] = buf[0];
++        offset[i] = 0;
++    }
++    if (ctx->do_add_rms_partials) {
++        buf[num_tensors] = ctx->prealloc_add_rms_partials;
++        offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
++    }
++
++    std::array elements;
++
++    uint32_t ne = ggml_nelements(dst);
++    if (ne > 262144) {
++        elements = { 512, 512, CEIL_DIV(ne, 262144) };
++    } else if (ne > 512) {
++        elements = { 512, CEIL_DIV(ne, 512), 1 };
++    } else {
++        elements = { ne, 1, 1 };
++    }
++
++    static_assert(MAX_PARAMETER_COUNT == 12);
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++        {
++            ggml_vk_subbuffer(ctx, buf[0], offset[0]),
++            ggml_vk_subbuffer(ctx, buf[1], offset[1]),
++            ggml_vk_subbuffer(ctx, buf[2], offset[2]),
++            ggml_vk_subbuffer(ctx, buf[3], offset[3]),
++            ggml_vk_subbuffer(ctx, buf[4], offset[4]),
++            ggml_vk_subbuffer(ctx, buf[5], offset[5]),
++            ggml_vk_subbuffer(ctx, buf[6], offset[6]),
++            ggml_vk_subbuffer(ctx, buf[7], offset[7]),
++            ggml_vk_subbuffer(ctx, buf[8], offset[8]),
++            ggml_vk_subbuffer(ctx, buf[9], offset[9]),
++            ggml_vk_subbuffer(ctx, buf[10], offset[10]),
++            ggml_vk_subbuffer(ctx, buf[11], offset[11]),
++        }, pc, elements);
++}
++
++static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, ctx->do_add_rms_partials,
++    });
++}
++
++static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t src2_type_size = ggml_type_size(src2->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
++        (uint32_t)dst->ne[0],
++        (uint32_t)dst->ne[1],
++        (uint32_t)src0->nb[1] / src0_type_size,
++        (uint32_t)src0->nb[2] / src0_type_size,
++        (uint32_t)src1->nb[1] / src1_type_size,
++        (uint32_t)src2->nb[1] / src2_type_size,
++    });
++}
++
++static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
++    GGML_ASSERT(version == 6 || version == 7);
++    int num_srcs = version == 6 ? 6 : 7;
++
++    for (int i = 0; i < num_srcs; i++) {
++        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
++    }
++
++    GGML_ASSERT(dst->buffer != nullptr);
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
++    GGML_ASSERT(pipeline != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++    vk_subbuffer src_buf[7] = {};
++    for (int i = 0; i < num_srcs; i++) {
++        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
++    }
++
++    std::array elements = {
++        (uint32_t)(pc.B * pc.H),
++        1,
++        1
++    };
++
++    if (version == 6) {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
++            pc, elements);
++    } else if (version == 7) {
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
++            pc, elements);
++    } else {
++        // shouldn't happen
++        GGML_ASSERT(false);
++    }
++}
++
++static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    const size_t seq_length = dst->src[0]->ne[2];
++    const size_t n_embed = dst->ne[0];
++    const size_t n_heads = dst->src[0]->ne[1];
++    const size_t n_seqs = dst->src[5]->ne[1];
++
++    ggml_vk_op_f32_wkv(
++        ctx, subctx, dst,
++        {
++            (uint32_t)n_seqs,
++            (uint32_t)seq_length,
++            (uint32_t)n_embed,
++            (uint32_t)n_heads,
++        },
++        6
++    );
++}
++
++static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    const size_t seq_length = dst->src[0]->ne[2];
++    const size_t n_embed = dst->ne[0];
++    const size_t n_heads = dst->src[0]->ne[1];
++    const size_t n_seqs = dst->src[6]->ne[1];
++
++    ggml_vk_op_f32_wkv(
++        ctx, subctx, dst,
++        {
++            (uint32_t)n_seqs,
++            (uint32_t)seq_length,
++            (uint32_t)n_embed,
++            (uint32_t)n_heads,
++        },
++        7
++    );
++}
++
++static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1];
++    const ggml_tensor * src2 = dst->src[2];
++    const ggml_tensor * src3 = dst->src[3];
++    const ggml_tensor * src4 = dst->src[4];
++    const ggml_tensor * src5 = dst->src[5];
++
++    GGML_ASSERT(dst->buffer != nullptr);
++
++    const uint32_t head_dim = src0->ne[1];
++    const uint32_t n_head = src1->ne[1];
++    const uint32_t n_group = src4->ne[1];
++    const uint32_t n_tok = src1->ne[2];
++    const uint32_t n_seq = src1->ne[3];
++
++    bool is_mamba2 = (src3->nb[1] == sizeof(float));
++    GGML_ASSERT(is_mamba2);
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
++    GGML_ASSERT(pipeline != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
++
++    const vk_op_ssm_scan_push_constants pc = {
++        (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
++        (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
++        (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
++        (uint32_t)src3->nb[1],
++        (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
++        (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
++        (uint32_t)s_off,
++        n_head, head_dim, n_group, n_tok
++    };
++
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++    vk_subbuffer src_buf[7] = {};
++    for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
++        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
++    }
++
++    std::array elements;
++
++    const uint32_t d_state = src0->ne[0];
++    uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
++    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
++    const uint32_t num_workgroups_y = n_seq;
++    elements = { num_workgroups_x, num_workgroups_y, 1 };
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
++        pc, elements);
++}
++
++static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1];
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
++        (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
++        (uint32_t)src1->nb[1],
++        (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
++        (uint32_t)src1->ne[0],
++        (uint32_t)src0->ne[0],
++        (uint32_t)src0->ne[1],
++        (uint32_t)dst->ne[1],
++        (uint32_t)dst->ne[2],
++    });
++}
++
++static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) {
++    const ggml_tensor * x = dst->src[0];
++    const ggml_tensor * g = dst->src[1];
++    const ggml_tensor * gm = dst->src[2];
++    const ggml_tensor * gv = dst->src[3];
++    const ggml_tensor * p = dst->src[4];
++
++    GGML_ASSERT(x->type == GGML_TYPE_F32);
++    GGML_ASSERT(g->type == GGML_TYPE_F32);
++    GGML_ASSERT(gm->type == GGML_TYPE_F32);
++    GGML_ASSERT(gv->type == GGML_TYPE_F32);
++    GGML_ASSERT(p->type == GGML_TYPE_F32);
++    GGML_ASSERT(dst->buffer != nullptr);
++    GGML_ASSERT(ggml_is_contiguous(x));
++    GGML_ASSERT(ggml_is_contiguous(g));
++    GGML_ASSERT(ggml_is_contiguous(gm));
++    GGML_ASSERT(ggml_is_contiguous(gv));
++    GGML_ASSERT(ggml_is_contiguous(p));
++    GGML_ASSERT(ggml_are_same_shape(x, g));
++    GGML_ASSERT(ggml_are_same_shape(x, gm));
++    GGML_ASSERT(ggml_are_same_shape(x, gv));
++    GGML_ASSERT(ggml_nelements(p) == 7);
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
++    GGML_ASSERT(pipeline != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
++    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
++    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
++    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
++    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);
++
++    std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 };
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++        {x_buf, g_buf, gm_buf, gv_buf, p_buf},
++        pc, elements);
++}
++
++static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    const size_t n = ggml_nelements(dst->src[0]);
++
++    ggml_vk_op_f32_opt_step_adamw(
++        ctx, subctx, dst,
++        { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
++    );
++}
++
++static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
++    const size_t n = ggml_nelements(dst->src[0]);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    int * op_params = (int *)dst->op_params;
++
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
++        (uint32_t)ggml_nelements(dst),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, op_params[0],
++    });
++}
++
++static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
++
++    GGML_TENSOR_UNARY_OP_LOCALS
++
++    float sf0 = (float)ne0 / ne00;
++    float sf1 = (float)ne1 / ne01;
++    float sf2 = (float)ne2 / ne02;
++    float sf3 = (float)ne3 / ne03;
++    float pixel_offset = 0.5f;
++
++    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
++        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
++        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
++        pixel_offset = 0.0f;
++    }
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
++        (uint32_t)ggml_nelements(dst), 0, 0,
++        (uint32_t)ne00, (uint32_t)ne01,
++        (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size,
++        (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
++        sf0, sf1, sf2, sf3, pixel_offset
++    });
++}
++
++static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
++    p.param1 = ggml_get_op_params_f32(dst, 0);
++    p.param2 = ggml_get_op_params_f32(dst, 1);
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p));
++}
++
++static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst));
++}
++
++static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
++}
++
++static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
++    VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
++
++    vk_op_push_constants pc = {
++        (uint32_t)ggml_nelements(dst),
++        1,
++        ggml_get_op_params_f32(dst, 0),
++        ggml_get_op_params_f32(dst, 2),
++        0.0f, 0.0f,
++    };
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
++    GGML_ASSERT(pipeline != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
++
++    std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
++}
++
++static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { // FILL: generator op, writes a constant into every element of dst
++    VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
++
++    vk_op_push_constants pc = {
++        (uint32_t)ggml_nelements(dst),
++        1,
++        ggml_get_op_params_f32(dst, 0), // fill value (op param 0)
++        0.0f,
++        0.0f, 0.0f,
++    };
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
++    GGML_ASSERT(pipeline != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
++
++    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; // explicit type: mixed uint32_t/int braces make CTAD ill-formed
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
++}
++
++static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // elementwise sin via the generic f32 unary-op path
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
++}
++
++static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // elementwise cos via the generic f32 unary-op path
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
++}
++
++static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // elementwise log via the generic f32 unary-op path
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
++}
++
++static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // TRI: unary op that also forwards op param 0 to the shader
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
++    p.param1 = ggml_get_op_params_f32(dst, 0); // op param 0 (triangle-mode selector per ggml TRI semantics — verify against ggml.h)
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
++}
++
++static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // DIAG: sized by dst elements since dst has more elements than src0
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
++}
++
++static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // CLAMP: bounds are carried to the shader in param1/param2
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
++    p.param1 = ggml_get_op_params_f32(dst, 0); // min bound (op param 0)
++    p.param2 = ggml_get_op_params_f32(dst, 1); // max bound (op param 1)
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p));
++}
++
++static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // PAD: uses the dedicated pad push-constant layout instead of the unary one
++    vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p));
++}
++
++static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // ROLL: shift amounts per dim come from op params 0..3
++    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
++    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
++    const int32_t s2 = ggml_get_op_params_i32(dst, 2);
++    const int32_t s3 = ggml_get_op_params_i32(dst, 3);
++    const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); // bias each shift by 0x8000 so negatives pack into 16 unsigned bits
++    const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
++
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
++    memcpy(&p.param1, &s01_packed, sizeof(float)); // smuggle the packed u32 bit patterns through the float params
++    memcpy(&p.param2, &s23_packed, sizeof(float));
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p));
++}
++
++static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // REPEAT: thread count follows dst (dst is larger than src0)
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p));
++}
++
++static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // REPEAT_BACK (gradient of repeat): thread count follows dst
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p));
++}
++
++static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // CPY / type-converting copy
++    uint32_t ne = (uint32_t)ggml_nelements(src0);
++    if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
++        // Convert from number of logical elements to 2- or 4-byte units.
++        ne /= ggml_blck_size(src0->type); // number of quantized blocks
++        if ((ggml_type_size(src0->type) % 4) == 0) {
++            ne *= ggml_type_size(src0->type) / 4; // copy in 4-byte words
++        } else {
++            ne *= ggml_type_size(src0->type) / 2; // copy in 2-byte halves
++        }
++    }
++
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p));
++}
++
++static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // SET_ROWS: writes rows of src0 into dst at indices given by src1
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    // Skip empty skip_rows operations. For most ops the empty check at the start
++    // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
++    // with empty srcs.
++    if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
++        return;
++    }
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, { // binary push-constant layout: ne/nb of all three tensors in element units
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // SILU_BACK: SiLU gradient; only the element count is needed by the shader
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // NORM (layer norm over rows); op_params[0] is eps
++    float * op_params = (float *)dst->op_params;
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // GROUP_NORM: normalize over groups of channels
++    const int * int_op_params = (const int *)dst->op_params;
++    const float * float_op_params = (const float *)dst->op_params;
++
++    const uint32_t num_groups = int_op_params[0];   // op param 0 (int): number of groups
++    const float eps = float_op_params[1];           // op param 1 (float): epsilon
++    const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); // elements per group, channels rounded up
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
++}
++
++static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { // number of per-workgroup partial sums an add_rms pass produces for one row of `node`
++    const uint32_t ne = (uint32_t)node->ne[0];
++    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; // elements handled per workgroup by the add_rms pipeline
++    const uint32_t num_partials = CEIL_DIV(ne, denom);
++    return num_partials;
++}
++
++static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { // byte size of the partials scratch area, rounded up to the device's binding alignment
++    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
++    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
++    return num_bytes;
++}
++
++static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) { // decode ROPE op params into the shader push-constant struct
++    const int n_dims        = ((const int32_t *) dst->op_params)[1];
++    const int mode          = ((const int32_t *) dst->op_params)[2];
++    // const int n_ctx         = ((const int32_t *) dst->op_params)[3];
++    const int n_ctx_orig    = ((const int32_t *) dst->op_params)[4];
++    const float freq_base   = ((const float *)   dst->op_params)[5];
++    const float freq_scale  = ((const float *)   dst->op_params)[6];
++    const float ext_factor  = ((const float *)   dst->op_params)[7];
++    const float attn_factor = ((const float *)   dst->op_params)[8];
++    const float beta_fast   = ((const float *)   dst->op_params)[9];
++    const float beta_slow   = ((const float *)   dst->op_params)[10];
++    int sections[4] {};
++    if (mode & GGML_ROPE_TYPE_MROPE) {
++        memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4); // per-section dims, multi-rope only
++    }
++
++    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
++
++    float corr_dims[2];
++    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); // YaRN interpolation/extrapolation boundary dims
++
++    const float theta_scale = powf(freq_base, -2.0f/n_dims); // per-dimension-pair frequency decay
++
++    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); // src strides in elements
++    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
++    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
++
++    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);   // dst strides in elements
++    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
++    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
++
++    vk_op_rope_push_constants rope {
++        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
++        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
++        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
++
++        (uint32_t)src0->ne[0],
++        (uint32_t)src0->ne[1],
++        (uint32_t)src0->ne[2],
++        nb01, nb02, nb03,
++        nb11, nb12, nb13,
++    };
++
++    return rope;
++}
++
++static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) { // RMS_NORM, optionally fused with MUL, or with MUL+ROPE(+VIEW+SET_ROWS)
++    ggml_tensor * dst;
++    const ggml_tensor * src0;
++    const ggml_tensor * src1;
++
++    if (ctx->num_additional_fused_ops > 0) {
++        // fused rms_norm + mul
++        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
++        ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
++        dst = mul;
++        src0 = cgraph->nodes[node_idx]->src[0];
++        src1 = other_src;
++    } else {
++        dst = cgraph->nodes[node_idx];
++        src0 = src1 = dst->src[0]; // unfused: alias src1 to src0 so the binary push-constant layout stays valid
++    }
++
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; // nonzero selects the partials-consuming shader path
++
++    vk_op_binary_push_constants bin {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        op_params[0], 0.0f, (int32_t)param3, // op_params[0] is the rms_norm epsilon
++    };
++
++    // more than one fused op means rms_norm+mul+rope
++    if (ctx->num_additional_fused_ops > 1) {
++        static constexpr uint32_t max_tensors = 7;
++        const ggml_tensor *tensors[max_tensors] {};
++
++        ggml_tensor *rms = cgraph->nodes[node_idx + 0];
++        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
++        ggml_tensor *rope = cgraph->nodes[node_idx + 2];
++
++        ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
++
++        bool do_set_rows = ctx->num_additional_fused_ops == 4; // rms+mul+rope+view+set_rows
++
++        tensors[0] = rms->src[0];
++        tensors[1] = other_src;
++        tensors[2] = mul;
++        tensors[3] = rope->src[1]; // pos
++        tensors[4] = rope->src[2]; // ff
++        tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
++        tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr; // set_rows row indices
++        const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
++
++        vk_op_rms_norm_mul_rope_push_constants pc;
++        pc.bin = bin;
++        pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
++
++        vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
++
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
++        vk_buffer buf[max_tensors];
++        size_t offset[max_tensors];
++        bool uma[max_tensors];
++
++        for (uint32_t i = 0; i < max_tensors; ++i) {
++            if (!tensors[i]) {
++                // If any remaining descriptors are unused, just point them at src[0]
++                buf[i] = buf[0];
++                offset[i] = 0;
++                continue;
++            }
++            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
++            buf[i] = nullptr;
++            offset[i] = 0;
++            uma[i] = false;
++
++            if (ctx->device->uma) {
++                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); // unified memory: resolve host pointer directly
++                uma[i] = buf[i] != nullptr;
++            }
++            if (!uma[i]) {
++                buf[i] = buf_ctx[i]->dev_buffer;
++                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
++            }
++            GGML_ASSERT(buf[i] != nullptr);
++        }
++
++        std::array<uint32_t, 3> elements; // fix: bare `std::array` without template args cannot be deduced (no initializer)
++        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
++
++        static_assert(max_tensors == 7);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
++            {
++                ggml_vk_subbuffer(ctx, buf[0], offset[0]),
++                ggml_vk_subbuffer(ctx, buf[1], offset[1]),
++                ggml_vk_subbuffer(ctx, buf[2], offset[2]),
++                ggml_vk_subbuffer(ctx, buf[3], offset[3]),
++                ggml_vk_subbuffer(ctx, buf[4], offset[4]),
++                ggml_vk_subbuffer(ctx, buf[5], offset[5]),
++                ggml_vk_subbuffer(ctx, buf[6], offset[6]),
++            }, pc, elements);
++    } else {
++        ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
++    }
++
++    if (ctx->do_add_rms_partials_offset_calculation) {
++        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); // advance scratch offset for the next fused add_rms
++        ctx->do_add_rms_partials = false;
++        ctx->do_add_rms_partials_offset_calculation = false;
++    }
++}
++
++static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // RMS_NORM gradient; op_params[0] is eps
++    float * op_params = (float *)dst->op_params;
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // L2_NORM; op_params[0] is eps
++    const float * op_params = (const float *)dst->op_params;
++    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
++    p.param1 = op_params[0];
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
++}
++
++static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // generic GGML_OP_UNARY dispatch (the specific unary op is picked from dst inside ggml_vk_op_f32)
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // xIELU activation: forwards its four float coefficients (op params 1..4) to the unary shader
++    float * op_params = (float *)dst->op_params;
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
++        {
++            (uint32_t)ggml_nelements(src0), 0,
++            op_params[1], op_params[2], op_params[3], op_params[4]
++        }
++    );
++}
++
++static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // gated linear unit; gate and value either share src0 (halves) or come from src0/src1 (split)
++    const float * op_params_f = (const float *)dst->op_params;
++
++    const bool swapped = (bool)dst->op_params[1]; // op param 1: swap gate/value halves
++    const bool split = src1 != nullptr;           // separate gate tensor supplied
++    const float alpha = op_params_f[2];
++    const float limit = op_params_f[3];
++
++    GGML_ASSERT(ggml_is_contiguous(src0));
++
++    if (!split) {
++        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); // single tensor holds both halves along dim 0
++    } else {
++        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
++        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
++        GGML_ASSERT(src0->type == src1->type);
++    }
++
++    const uint32_t mode = split ? 2 : (swapped ? 1 : 0); // 0 = normal halves, 1 = swapped halves, 2 = split tensors
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
++        {
++            (uint32_t)ggml_nelements(dst),
++            (uint32_t)src0->ne[0],
++            (uint32_t)dst->ne[0],
++            mode,
++            alpha,
++            limit
++        });
++}
++
++static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // causal masking; op_params[0] is n_past
++    int32_t * op_params = (int32_t *)dst->op_params;
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
++}
++
++static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { // SOFT_MAX with optional mask (src1) and sinks (src2)
++    float * op_params = (float *)dst->op_params;
++
++    float scale = op_params[0];
++    float max_bias = op_params[1]; // ALiBi slope bias; 0 disables
++
++    const uint32_t ncols =   (uint32_t)src0->ne[0];
++    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
++    const uint32_t nrows_y = (uint32_t)src0->ne[1];
++
++    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; // mask shape/strides, zeroed when no mask
++    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
++    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
++    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
++    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
++
++    const uint32_t n_head_kv   = src0->ne[2];
++    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); // largest power of two <= n_head_kv
++
++    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); // ALiBi slope bases
++    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
++
++    vk_op_soft_max_push_constants pc {
++        ncols,
++        src1 != nullptr ? nrows_y : (uint32_t)0,
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
++        ne12, ne13,
++        nb11, nb12, nb13,
++        scale, max_bias,
++        m0, m1,
++        n_head_log2,
++        nrows_x,
++        src2 != nullptr
++    };
++
++    if (ncols <= 16384) {
++        ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
++    } else {
++        // Very wide rows: three-pass multi-workgroup softmax using two scratch buffers for row maxima/sums.
++        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
++        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
++        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
++        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
++
++        uint32_t elems_per_wg = 128 * 4;
++        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
++        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
++
++        if (ctx->prealloc_size_x < tmp_size) {
++            ctx->prealloc_size_x = tmp_size;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (ctx->prealloc_size_y < tmp_size) {
++            ctx->prealloc_size_y = tmp_size;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++
++        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
++        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
++
++        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 }; // explicit type: mixed uint32_t/int braces make CTAD ill-formed
++
++        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
++        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
++        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
++
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
++
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
++        ggml_vk_sync_buffers(ctx, subctx);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
++        ggml_vk_sync_buffers(ctx, subctx);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
++
++        ctx->prealloc_x_need_sync = true;
++        ctx->prealloc_y_need_sync = true;
++    }
++}
++
++static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { // SOFT_MAX gradient; op_params[0..1] are scale and max_bias
++    float * op_params = (float *)dst->op_params;
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
++}
++
++static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { // fused MoE router: gating + top-k selection in one dispatch; node offsets depend on the fusion mode
++    topk_moe_mode mode = ctx->fused_topk_moe_mode;
++    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
++    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
++    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
++    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
++                        (mode == TOPK_MOE_LATE_SOFTMAX) ?      cgraph->nodes[node_idx + 1] :
++                                                               cgraph->nodes[node_idx + 3];
++
++    GGML_ASSERT(logits->type == GGML_TYPE_F32);
++    GGML_ASSERT(bias->type == GGML_TYPE_F32);
++    GGML_ASSERT(weights->type == GGML_TYPE_F32);
++    GGML_ASSERT(ids->type == GGML_TYPE_I32);
++
++    const int n_experts = logits->ne[0];
++    const int n_rows    = logits->ne[1];
++    const int n_expert_used = weights->ne[1];
++
++    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); // ids rows must be densely packed
++
++    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++
++    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
++    vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
++    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
++    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
++
++    vk_op_topk_moe_push_constants pc {};
++    pc.n_rows = n_rows;
++    pc.n_experts_push = n_experts;
++    pc.n_expert_used = n_expert_used;
++    pc.clamp_min = -std::numeric_limits<float>::infinity(); // fix: numeric_limits requires its <float> template argument; default = no clamping
++    pc.clamp_max = std::numeric_limits<float>::infinity();
++    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
++        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
++        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
++        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
++        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
++    }
++    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
++        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
++        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
++        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
++        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
++    }
++
++#define GATING_FUNC_SOFTMAX 0
++#define GATING_FUNC_SIGMOID 1
++#define GATING_FUNC_SOFTMAX_WEIGHT 2
++
++    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
++                     mode == TOPK_MOE_LATE_SOFTMAX ?      GATING_FUNC_SOFTMAX_WEIGHT :
++                                                          GATING_FUNC_SOFTMAX;
++    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
++    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
++    if (ctx->fused_topk_moe_scale) {
++        GGML_ASSERT(weights->op == GGML_OP_SCALE);
++        pc.output_scale = ggml_get_op_params_f32(weights, 0);
++        pc.output_bias = ggml_get_op_params_f32(weights, 1);
++    } else {
++        pc.output_scale = 1.0f;
++        pc.output_bias = 0.0f;
++    }
++
++    GGML_ASSERT(n_expert_used <= n_experts);
++
++    const uint32_t rows_per_block = 4;
++    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; // explicit type: mixed-type braces make CTAD ill-formed
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
++}
++
++static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { // ROPE, optionally fused with VIEW+SET_ROWS
++    ggml_tensor * dst = cgraph->nodes[node_idx];
++    const ggml_tensor * src0 = dst->src[0];
++    const ggml_tensor * src1 = dst->src[1]; // positions
++    const ggml_tensor * src2 = dst->src[2]; // optional frequency factors
++    const ggml_tensor * src3 = nullptr;
++    const int n_dims        = ((int32_t *) dst->op_params)[1];
++    const int mode          = ((int32_t *) dst->op_params)[2];
++    // const int n_ctx         = ((int32_t *) dst->op_params)[3];
++    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
++    const float freq_base   = ((float *)   dst->op_params)[5];
++    const float beta_fast   = ((float *)   dst->op_params)[9];
++    const float beta_slow   = ((float *)   dst->op_params)[10];
++    int sections[4] {};
++    if (mode & GGML_ROPE_TYPE_MROPE) {
++        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
++    }
++
++    float corr_dims[2];
++    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); // computed for side effects parity; push constants are rebuilt below
++
++    uint32_t set_rows_stride = 0;
++    // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
++    // and overrides the dst and sets src3=row_indices
++    if (ctx->num_additional_fused_ops > 0) {
++        set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
++        src3 = cgraph->nodes[node_idx + 2]->src[1];
++        dst = cgraph->nodes[node_idx + 2];
++    }
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
++        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
++}
++
++static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // ARGSORT: bitonic sort, single-workgroup fast path or multi-dispatch for wide rows
++    const uint32_t * op_params = (const uint32_t *)dst->op_params;
++
++    uint32_t ncols = src0->ne[0];
++    uint32_t nrows = ggml_nrows(src0);
++
++    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
++    uint32_t ncolsp2 = 1 << ncols_pad_log2; // bitonic sort needs a power-of-two width
++
++    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; // op_params[0] is the sort order
++
++    // Pick the largest workgroup size <= ncolsp2
++    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
++
++    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
++    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
++                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
++
++    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
++                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
++
++    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++    vk_subbuffer subbuf1 = dst_buf;
++
++    // Reserve space for ivec2 per element, with rows padded to a power of two
++    if (!use_small) {
++        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
++
++        if (ctx->prealloc_size_x < x_sz) {
++            ctx->prealloc_size_x = x_sz;
++            ggml_vk_preallocate_buffers(ctx, subctx);
++        }
++        if (ctx->prealloc_x_need_sync) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
++    }
++
++    std::array<uint32_t, 3> elements; // fix: bare `std::array` without template args cannot be deduced (no initializer)
++
++    elements[0] = ncolsp2;
++    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++    elements[2] = 1;
++
++    // First dispatch initializes tmp_idx and does the first N passes where
++    // there is only communication between threads in the same workgroup.
++    {
++        vk_op_argsort_push_constants pc2 = pc;
++        pc2.outer_start = 0;
++        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
++        pc2.inner_start = 0;
++        pc2.inner_end = 100;
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
++    }
++    if (!use_small) {
++        ggml_vk_sync_buffers(ctx, subctx);
++        // Loop over outer/inner passes, synchronizing between each pass.
++        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
++            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
++                vk_op_argsort_push_constants pc2 = pc;
++                pc2.outer_start = outer;
++                pc2.outer_end = outer + 1;
++                pc2.inner_start = inner;
++                pc2.inner_end = inner + 1;
++                // When the inner idx is large enough, there's only communication
++                // within a workgroup. So the remaining inner iterations can all
++                // run in the same dispatch.
++                if (outer - inner < pipeline_idx) {
++                    pc2.inner_end = 100;
++                    inner = outer;
++                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
++                } else {
++                    // Smaller workgroup empirically seems to perform better
++                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
++                }
++                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
++                ggml_vk_sync_buffers(ctx, subctx);
++            }
++        }
++        ctx->prealloc_x_need_sync = true;
++    }
++}
++
++static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    uint32_t ncols = src0->ne[0];
++    uint32_t nrows = ggml_nrows(src0);
++    uint32_t k = dst->ne[0];
++
++    vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
++
++    if (ctx->prealloc_x_need_sync) {
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++
++    std::array elements;
++    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
++    elements[2] = 1;
++
++    uint32_t num_elements = ncols;
++
++    // Each iteration reduces a workgroup's worth of elements down to the K
++    // largest elements. Repeat until we have the top K elements.
++    // Need to do at least one iteration to write out the results.
++    bool done_one_iter = false;
++    uint32_t dbl_buf_index = 0;
++    size_t dbl_buf_size;
++    while (num_elements > k || !done_one_iter) {
++
++        // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
++        // But if K is larger, then we need a larger workgroup
++        uint32_t max_pipeline = num_topk_pipelines - 1;
++        uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
++        max_pipeline = std::min(preferred_pipeline, max_pipeline);
++        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
++        // require full subgroup
++        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
++
++        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
++        pipeline_idx = std::min(pipeline_idx, max_pipeline);
++        pipeline_idx = std::max(pipeline_idx, min_pipeline);
++
++        if (num_elements > (1u << pipeline_idx)) {
++            // If we could finish on this loop iteration (i.e. a single workgroup)
++            // then do so. It's better than the overhead of another pass.
++            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
++                if (num_elements <= (1u << i)) {
++                    pipeline_idx = i;
++                    break;
++                }
++            }
++        }
++
++        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
++        // If the device doesn't support a pipeline this large, use smaller
++        while (!pipeline) {
++            pipeline_idx--;
++            GGML_ASSERT(pipeline_idx >= min_pipeline);
++            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
++        }
++
++        vk_op_topk_push_constants pc2 = pc;
++        pc2.ncols_input = num_elements;
++
++        // Number of elements remaining after this pass
++        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
++
++        pc2.ncols_output = num_dst_elements;
++
++        if (!done_one_iter) {
++            // Reserve space for ivec2 per element, double buffered
++            // K per workgroup per row
++            dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
++            dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
++            const size_t x_sz = dbl_buf_size * 2;
++
++            if (ctx->prealloc_size_x < x_sz) {
++                ctx->prealloc_size_x = x_sz;
++                ggml_vk_preallocate_buffers(ctx, subctx);
++            }
++        }
++
++        vk_subbuffer src_buf;
++        vk_subbuffer dst_buf;
++
++        if (num_elements == ncols) {
++            pc2.first_pass = 1;
++            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
++        } else {
++            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
++        }
++        if (num_dst_elements == k) {
++            pc2.last_pass = 1;
++            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++        } else {
++            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
++        }
++
++        elements[0] = num_elements;
++
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
++        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
++        num_elements = num_dst_elements;
++        dbl_buf_index ^= 1;
++        if (num_elements > k) {
++            ggml_vk_sync_buffers(ctx, subctx);
++        }
++        done_one_iter = true;
++    }
++    ctx->prealloc_x_need_sync = true;
++}
++
++static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
++}
++
++static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p);
++}
++
++static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
++    p.weight = 1.0f / (float)src0->ne[0];
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
++}
++
++static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
++    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
++    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
++    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
++        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
++        return;
++    }
++
++    // First pass computes partial sums within a block, and stores the last partial
++    // to the temp buffer. Second pass sums the block partials from the temp buffer
++    // and adds that to the result of the first pass.
++    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
++    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
++    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
++
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
++    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
++
++    std::array elements;
++
++    elements[0] = dst->ne[0];
++    elements[1] = (uint32_t)ggml_nrows(dst);
++    elements[2] = 1;
++
++    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
++
++    if (ctx->prealloc_size_split_k < temp_size) {
++        ctx->prealloc_size_split_k = temp_size;
++        ggml_vk_preallocate_buffers(ctx, subctx);
++    }
++
++    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
++    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
++    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
++
++    if (ctx->prealloc_split_k_need_sync) {
++        ggml_vk_sync_buffers(ctx, subctx);
++    }
++
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
++    ggml_vk_sync_buffers(ctx, subctx);
++    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
++
++    ctx->prealloc_split_k_need_sync = true;
++}
++
++static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
++}
++
++static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const uint32_t src0_type_size = ggml_type_size(src0->type);
++    const uint32_t src1_type_size = ggml_type_size(src1->type);
++    const uint32_t dst_type_size = ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
++        (uint32_t)ggml_nelements(src0),
++        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
++        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
++        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
++        0,
++        0.0f, 0.0f, 0,
++    });
++}
++
++static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    const int32_t s0 = dst->op_params[0];
++    const int32_t s1 = dst->op_params[1];
++    const int32_t p0 = dst->op_params[2];
++    const int32_t p1 = dst->op_params[3];
++    const int32_t d0 = dst->op_params[4];
++    const int32_t d1 = dst->op_params[5];
++
++    const bool is_2D = dst->op_params[6] == 1;
++
++    const uint32_t IC = src1->ne[is_2D ? 2 : 1];
++    const uint32_t IH = is_2D ? src1->ne[1] : 1;
++    const uint32_t IW =         src1->ne[0];
++
++    const uint32_t KH = is_2D ? src0->ne[1] : 1;
++    const uint32_t KW =         src0->ne[0];
++
++    const uint32_t OH = is_2D ? dst->ne[2] : 1;
++    const uint32_t OW =         dst->ne[1];
++
++    const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
++    const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
++
++    const uint32_t pelements = OW * KW * KH;
++    const uint32_t batch = src1->ne[is_2D ? 3 : 2];
++
++    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
++
++    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
++        dst_addr,
++        batch_offset, offset_delta,
++        IC, IW, IH, OW, OH, KW, KH,
++        pelements,
++        IC * KH * KW,
++        s0, s1, p0, p1, d0, d1, batch * IC
++    });
++}
++
++static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    GGML_TENSOR_BINARY_OP_LOCALS
++
++    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
++    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
++    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
++    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
++    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
++    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
++    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
++    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
++    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
++    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
++
++    const int64_t N  = ne13 / IC;
++    const int64_t ID = ne12;
++    const int64_t IH = ne11;
++    const int64_t IW = ne10;
++
++    const int64_t KD = ne02;
++    const int64_t KH = ne01;
++    const int64_t KW = ne00;
++
++    const int64_t OD = ne3 / N;
++    const int64_t OH = ne2;
++    const int64_t OW = ne1;
++
++    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
++
++    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
++
++    vk_op_im2col_3d_push_constants pc {};
++
++    pc.dst_addr = dst_addr;
++    pc.nb10 = nb10 / ggml_type_size(src1->type);
++    pc.nb11 = nb11 / ggml_type_size(src1->type);
++    pc.nb12 = nb12 / ggml_type_size(src1->type);
++    pc.nb13 = nb13 / ggml_type_size(src1->type);
++    pc.s0 = s0;
++    pc.s1 = s1;
++    pc.s2 = s2;
++    pc.p0 = p0;
++    pc.p1 = p1;
++    pc.p2 = p2;
++    pc.d0 = d0;
++    pc.d1 = d1;
++    pc.d2 = d2;
++    pc.IW = IW;
++    pc.IH = IH;
++    pc.ID = ID;
++    pc.IC = IC;
++    pc.KW = KW;
++    pc.OH = OH;
++    pc.KD_KH_KW = KD*KH*KW;
++    pc.KH_KW = KH*KW;
++    pc.IC_KD_KH_KW = IC*KD*KH*KW;
++    pc.N_OD_OH = N*OD*OH;
++    pc.OD_OH = OD*OH;
++    pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
++    pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
++    pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc));
++}
++
++static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    const uint32_t dim = dst->op_params[0];
++    const uint32_t max_period = dst->op_params[1];
++    const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
++        nb1, dim, max_period,
++    });
++}
++
++static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    // src0: (K, Cout, Cin, 1) -- kernel
++    // src1: (L, Cin, 1, 1) -- input
++    // dst: (*, Cout, 1, 1)
++
++    GGML_ASSERT(src0->type == GGML_TYPE_F32);
++    GGML_ASSERT(src1->type == GGML_TYPE_F32);
++    GGML_ASSERT( dst->type == GGML_TYPE_F32);
++
++    GGML_TENSOR_BINARY_OP_LOCALS
++
++    GGML_ASSERT(nb00 == sizeof(float));
++    GGML_ASSERT(nb10 == sizeof(float));
++
++    const int32_t s0 = dst->op_params[0];
++
++    vk_op_conv_transpose_1d_push_constants p{};
++    p.Cout = static_cast(ne01);
++    p.Cin = static_cast(ne02);
++    p.K = static_cast(ne00);
++    p.L = static_cast(ne10);
++    p.KL = static_cast(ne0);
++    p.nb01 = static_cast(nb01 / nb00);
++    p.nb02 = static_cast(nb02 / nb00);
++    p.nb11 = static_cast(nb11 / nb10);
++    p.nb1 = static_cast(nb1 / nb0);
++    p.s0 = static_cast(s0);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
++}
++
++static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    uint32_t op = static_cast(dst->op_params[0]);
++    const int32_t k1 = dst->op_params[1];
++    const int32_t k0 = dst->op_params[2];
++    const int32_t s1 = dst->op_params[3];
++    const int32_t s0 = dst->op_params[4];
++    const int32_t p1 = dst->op_params[5];
++    const int32_t p0 = dst->op_params[6];
++
++    const uint32_t IH = src0->ne[1];
++    const uint32_t IW = src0->ne[0];
++
++    const uint32_t N = dst->ne[3];
++
++    const uint32_t OC = dst->ne[2];
++    const uint32_t OH = dst->ne[1];
++    const uint32_t OW = dst->ne[0];
++
++    const uint32_t parallel_elements = N * OC * OH * OW;
++
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
++        IW, IH, OW, OH, OC,
++        parallel_elements,
++        op,
++        k0, k1, s0, s1, p0, p1,
++    });
++}
++
++static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
++                            const ggml_tensor * src1, ggml_tensor * dst) {
++    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
++    GGML_ASSERT(src1->type == GGML_TYPE_F32);
++    GGML_ASSERT(dst->type == GGML_TYPE_F32);
++
++    GGML_TENSOR_BINARY_OP_LOCALS
++    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
++    GGML_ASSERT(nb10 == sizeof(float));
++    GGML_ASSERT(nb0 == sizeof(float));
++
++    bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
++
++    vk_op_conv2d_push_constants p{};
++    p.Cout = static_cast(!transpose ? ne03 : ne02);
++    p.Cin  = static_cast(!transpose ? ne02 : ne03);
++    p.N    = static_cast(ne13);
++    GGML_ASSERT(p.Cout == ne2);
++    GGML_ASSERT(p.Cin == ne12);
++
++    p.W  = static_cast(ne10);
++    p.H  = static_cast(ne11);
++    p.OW = static_cast(ne0);
++    p.OH = static_cast(ne1);
++
++    p.nb01 = static_cast(nb01 / nb00);
++    p.nb02 = static_cast(nb02 / nb00);
++    p.nb03 = static_cast(nb03 / nb00);
++
++    p.nb11 = static_cast(nb11 / nb10);
++    p.nb12 = static_cast(nb12 / nb10);
++    p.nb13 = static_cast(nb13 / nb10);
++
++    p.nb1 = static_cast(nb1 / nb0);
++    p.nb2 = static_cast(nb2 / nb0);
++    p.nb3 = static_cast(nb3 / nb0);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
++}
++
++static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
++    vk_op_conv2d_dw_push_constants p{};
++    p.ne = ggml_nelements(dst);
++    p.channels = dst->ne[2];
++    p.batches = dst->ne[3];
++    p.dst_w = dst->ne[0];
++    p.dst_h = dst->ne[1];
++    p.src_w = src1->ne[0];
++    p.src_h = src1->ne[1];
++    p.knl_w = src0->ne[0];
++    p.knl_h = src0->ne[1];
++    p.stride_x = dst->op_params[0];
++    p.stride_y = dst->op_params[1];
++    p.pad_x = dst->op_params[2];
++    p.pad_y = dst->op_params[3];
++    p.dilation_x = dst->op_params[4];
++    p.dilation_y = dst->op_params[5];
++
++    GGML_ASSERT(src0->ne[3] == p.channels);
++    GGML_ASSERT(src1->ne[3] == p.batches);
++
++    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p));
++}
++
++static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
++    const float * op_params = (const float *)dst->op_params;
++    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
++}
++
++#ifdef GGML_VULKAN_RUN_TESTS
++static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
++    if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
++        return;
++    }
++    i0 = std::max(i0, 5);
++    i1 = std::max(i1, 5);
++    i2 = std::max(i2, 0);
++    fprintf(stderr, "         ");
++    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++        fprintf(stderr, "%7d ", idx1);
++    }
++    fprintf(stderr, "\n");
++    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
++        fprintf(stderr, "%7d: ", idx0);
++        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++            if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
++                float val;
++                if (type == GGML_TYPE_F32) {
++                    val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
++                } else if (type == GGML_TYPE_F16) {
++                    val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
++                } else {
++                    GGML_ABORT("fatal error");
++                }
++                fprintf(stderr, "% 7.2f ", val);
++            } else {
++                fprintf(stderr, "        ");
++            }
++        }
++        fprintf(stderr, "\n");
++    }
++}
++
++template 
++static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
++    VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
++    const size_t x_ne = m * k * batch;
++    const size_t y_ne = k * n * batch;
++    const size_t d_ne = m * n * batch;
++
++    vk_pipeline p;
++    std::string shname;
++    if (shader_size == 0) {
++        if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32->a_s;
++            shname = "F32_ALIGNED_S";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32_f16->a_s;
++            shname = "F32_F16_ALIGNED_S";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
++            shname = "F16_F32_ALIGNED_S";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
++            shname = "F16_ALIGNED_S";
++        } else {
++            GGML_ABORT("fatal error");
++        }
++    } else if (shader_size == 1) {
++        if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32->a_m;
++            shname = "F32_ALIGNED_M";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32_f16->a_m;
++            shname = "F32_F16_ALIGNED_M";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
++            shname = "F16_F32_ALIGNED_M";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
++            shname = "F16_ALIGNED_M";
++        } else {
++            GGML_ABORT("fatal error");
++        }
++    } else if (shader_size == 2) {
++        if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32->a_l;
++            shname = "F32_ALIGNED_L";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f32_f16->a_l;
++            shname = "F32_F16_ALIGNED_L";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
++            shname = "F16_F32_ALIGNED_L";
++        } else if (std::is_same() && std::is_same()) {
++            p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
++            shname = "F16_ALIGNED_L";
++        } else {
++            GGML_ABORT("fatal error");
++        }
++    } else {
++        GGML_ASSERT(0);
++    }
++
++    const size_t kpad = ggml_vk_align_size(k, p->align);
++
++    if (k != kpad) {
++        if (shader_size == 0) {
++            if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32->s;
++                shname = "F32_S";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32_f16->s;
++                shname = "F32_F16_S";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
++                shname = "F16_F32_S";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16.f32acc->s;
++                shname = "F16_S";
++            }
++        } else if (shader_size == 1) {
++            if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32->m;
++                shname = "F32_M";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32_f16->m;
++                shname = "F32_F16_M";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
++                shname = "F16_F32_M";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16.f32acc->m;
++                shname = "F16_M";
++            }
++        } else if (shader_size == 2) {
++            if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32->l;
++                shname = "F32_L";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f32_f16->l;
++                shname = "F32_F16_L";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
++                shname = "F16_F32_L";
++            } else if (std::is_same() && std::is_same()) {
++                p = ctx->device->pipeline_matmul_f16.f32acc->l;
++                shname = "F16_L";
++            }
++        }
++    }
++
++    if (split_k > 1) {
++        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
++
++        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
++            // Resize buffer
++            if (ctx->prealloc_split_k != nullptr) {
++                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
++            }
++            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++        }
++    }
++
++    ggml_pipeline_allocate_descriptor_sets(ctx);
++
++    vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++
++    X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
++    Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
++    float* d = (float *) malloc(sizeof(float) * d_ne);
++
++    for (size_t i = 0; i < x_ne; i++) {
++        if (std::is_same()) {
++            x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
++            // x[i] = 1.0f;
++            // x[i] = i + 1;
++            // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
++        } else if (std::is_same()) {
++            x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
++            // x[i] = ggml_fp32_to_fp16(1.0f);
++            // x[i] = ggml_fp32_to_fp16(i + 1);
++            // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
++        } else {
++            GGML_ABORT("fatal error");
++        }
++    }
++    for (size_t i = 0; i < y_ne; i++) {
++        if (std::is_same()) {
++            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
++            // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
++            // y[i] = i + 1;
++        } else if (std::is_same()) {
++            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
++            // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
++            // y[i] = ggml_fp32_to_fp16(i + 1);
++        } else {
++            GGML_ABORT("fatal error");
++        }
++    }
++
++    ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
++    ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
++
++    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++    ggml_vk_ctx_begin(ctx->device, subctx);
++    for (size_t i = 0; i < num_it; i++) {
++        ggml_vk_matmul(
++            ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
++            m, n, k,
++            k, k, m, k*m, k*n, m*n,
++            split_k, batch, batch, batch, 1, 1, n
++        );
++    }
++    ggml_vk_ctx_end(subctx);
++
++    auto begin = std::chrono::high_resolution_clock::now();
++    ggml_vk_submit(subctx, ctx->fence);
++    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
++    ctx->device->device.resetFences({ ctx->fence });
++    ggml_vk_queue_command_pools_cleanup(ctx->device);
++
++    auto end = std::chrono::high_resolution_clock::now();
++    double time = std::chrono::duration_cast(end-begin).count() / 1000.0;
++
++    // copy dst to host
++    ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
++
++    float * d_chk = (float *) malloc(sizeof(float) * d_ne);
++
++    ggml_init_params iparams = {
++        /*.mem_size   =*/ 1024*1024*1024,
++        /*.mem_buffer =*/ NULL,
++        /*.no_alloc   =*/ true,
++    };
++
++    ggml_context * ggml_ctx = ggml_init(iparams);
++
++    ggml_type src0_type;
++    ggml_type src1_type;
++
++    if (std::is_same()) {
++        src0_type = GGML_TYPE_F32;
++    } else if (std::is_same()) {
++        src0_type = GGML_TYPE_F16;
++    } else {
++        GGML_ABORT("fatal error");
++    }
++    if (std::is_same()) {
++        src1_type = GGML_TYPE_F32;
++    } else if (std::is_same()) {
++        src1_type = GGML_TYPE_F16;
++    } else {
++        GGML_ABORT("fatal error");
++    }
++
++    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
++    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
++    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
++
++    src0_ggml->data = x;
++    src1_ggml->data = y;
++    tensor_ggml->data = d_chk;
++
++    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
++    ggml_build_forward_expand(cgraph, tensor_ggml);
++
++    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
++
++    ggml_free(ggml_ctx);
++
++    double avg_err = 0.0;
++    int first_err_n = -1;
++    int first_err_m = -1;
++    int first_err_b = -1;
++
++    for (size_t i = 0; i < m*n*batch; i++) {
++        double err = std::fabs(d[i] - d_chk[i]);
++        avg_err += err;
++
++        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
++            first_err_b = i / (m * n);
++            first_err_n = (i % (m * n)) / m;
++            first_err_m = (i % (m * n)) % m;
++        }
++    }
++
++    avg_err /= m * n;
++
++    double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
++
++    std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
++
++    if (avg_err > 0.1 || std::isnan(avg_err)) {
++        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
++        std::cerr << "Actual result: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++        std::cerr << "Expected result: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++        if (split_k > 1) {
++            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
++            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
++
++            std::cerr << "d_buf0: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf1: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf2: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf3: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            free(split_k_buf);
++        }
++    }
++
++    free(d_chk);
++
++    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
++
++    ggml_vk_destroy_buffer(d_X);
++    ggml_vk_destroy_buffer(d_Y);
++    ggml_vk_destroy_buffer(d_D);
++
++    free(x);
++    free(y);
++    free(d);
++}
++
++static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
++    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
++        return;
++    }
++    i0 = std::max(i0, 5);
++    i1 = std::max(i1, 5);
++    i2 = std::max(i2, 0);
++    i3 = std::max(i3, 0);
++    fprintf(stderr, "         ");
++    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++        fprintf(stderr, "%7d ", idx1);
++    }
++    fprintf(stderr, "\n");
++    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
++        fprintf(stderr, "%7d: ", idx0);
++        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
++                float val;
++                if (tensor->type == GGML_TYPE_F32) {
++                    val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
++                } else if (tensor->type == GGML_TYPE_F16) {
++                    val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
++                } else {
++                    GGML_ABORT("fatal error");
++                }
++                fprintf(stderr, "% 7.2f ", val);
++            } else {
++                fprintf(stderr, "        ");
++            }
++        }
++        fprintf(stderr, "\n");
++    }
++}
++
++static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
++    ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
++}
++
++static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
++    if (quant == GGML_TYPE_F32) {
++        memcpy(to, from, sizeof(float) * ne);
++        return;
++    }
++
++    const auto * tt = ggml_get_type_traits(quant);
++
++    ggml_to_float_t dequant_fn = tt->to_float;
++
++    dequant_fn(from, to, ne);
++}
++
++static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
++    VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
++    const size_t x_sz = sizeof(float) * ne;
++    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
++    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
++    float * x = (float *) malloc(x_sz);
++    void * qx = malloc(qx_sz);
++    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    float * x_ref = (float *) malloc(x_sz);
++    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
++
++    for (size_t i = 0; i < ne; i++) {
++        x[i] = rand() / (float)RAND_MAX;
++    }
++
++    vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
++
++    ggml_vk_quantize_data(x, qx, ne, quant);
++    ggml_vk_dequantize_data(qx, x_ref, ne, quant);
++
++    ggml_pipeline_request_descriptor_sets(ctx, p, 1);
++
++    ggml_pipeline_allocate_descriptor_sets(ctx);
++
++    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
++
++    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++    ggml_vk_ctx_begin(ctx->device, subctx);
++    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
++    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
++    ggml_vk_ctx_end(subctx);
++
++    auto begin = std::chrono::high_resolution_clock::now();
++
++    ggml_vk_submit(subctx, ctx->fence);
++    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
++    ctx->device->device.resetFences({ ctx->fence });
++    ggml_vk_queue_command_pools_cleanup(ctx->device);
++
++    auto end = std::chrono::high_resolution_clock::now();
++
++    double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
++    ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
++
++    int first_err = -1;
++
++    double avg_err = 0.0;
++    for (size_t i = 0; i < ne; i++) {
++        double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
++        avg_err += error;
++
++        if (first_err < 0 && error > 0.05) {
++            first_err = i;
++        }
++    }
++
++    avg_err /= ne;
++
++    std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
++
++    if (avg_err > 0.1) {
++        std::cerr << "first_error = " << first_err << std::endl;
++        std::cerr << "Actual result: " << std::endl << std::endl;
++        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
++            std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
++        }
++        std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
++        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
++            std::cerr << x_ref[i] << ", ";
++        }
++        std::cerr << std::endl;
++    }
++
++    ggml_vk_destroy_buffer(x_buf);
++    ggml_vk_destroy_buffer(qx_buf);
++
++    free(x);
++    free(qx);
++    free(x_ref);
++    free(x_chk);
++}
++
++// This does not work without ggml q8_1 quantization support
++//
++// typedef uint16_t ggml_half;
++// typedef uint32_t ggml_half2;
++//
++// #define QK8_1 32
++// typedef struct {
++//     union {
++//         struct {
++//             ggml_half d; // delta
++//             ggml_half s; // d * sum(qs[i])
++//         } GGML_COMMON_AGGR_S;
++//         ggml_half2 ds;
++//     } GGML_COMMON_AGGR_U;
++//     int8_t qs[QK8_1]; // quants
++// } block_q8_1;
++//
++// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
++//     VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
++//     GGML_ASSERT(quant == GGML_TYPE_Q8_1);
++//
++//     const size_t x_sz = sizeof(float) * ne;
++//     const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
++//     float * x = (float *) malloc(x_sz);
++//     block_q8_1 * qx     = (block_q8_1 *)malloc(qx_sz);
++//     block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
++//     vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++//     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++//
++//     for (size_t i = 0; i < ne; i++) {
++//         x[i] = rand() / (float)RAND_MAX;
++//     }
++//
++//     vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
++//
++//     ggml_pipeline_request_descriptor_sets(ctx, p, 1);
++//
++//     ggml_pipeline_allocate_descriptor_sets(ctx);
++//
++//     ggml_vk_buffer_write(x_buf, 0, x, x_sz);
++//
++//     vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++//     ggml_vk_ctx_begin(ctx->device, subctx);
++//     ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
++//     ggml_vk_ctx_end(subctx);
++//
++//     auto begin = std::chrono::high_resolution_clock::now();
++//
++//     ggml_vk_submit(subctx, ctx->fence);
++//     VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
++//     ctx->device->device.resetFences({ ctx->fence });
++//     ggml_vk_queue_command_pools_cleanup(ctx->device);
++//
++//     auto end = std::chrono::high_resolution_clock::now();
++//
++//     double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
++//     ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
++//
++//     ggml_vk_quantize_data(x, qx_res, ne, quant);
++//
++//     int first_err = -1;
++//
++//     for (size_t i = 0; i < ne / 32; i++) {
++//         double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
++//
++//         if (first_err < 0 && error > 0.1) {
++//             first_err = i;
++//         }
++//
++//         error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
++//
++//         if (first_err < 0 && error > 0.1) {
++//             first_err = i;
++//         }
++//
++//         for (size_t j = 0; j < 32; j++) {
++//             uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
++//
++//             if (first_err < 0 && error > 1) {
++//                 first_err = i;
++//             }
++//         }
++//     }
++//
++//     std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
++//
++//     if (first_err != -1) {
++//         std::cerr << "first_error = " << first_err << std::endl;
++//         std::cerr << "Actual result: " << std::endl << std::endl;
++//         std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
++//         for (size_t j = 0; j < 32; j++) {
++//             std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
++//         }
++//         std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
++//         std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
++//         for (size_t j = 0; j < 32; j++) {
++//             std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
++//         }
++//         std::cerr << std::endl;
++//     }
++//
++//     ggml_vk_destroy_buffer(x_buf);
++//     ggml_vk_destroy_buffer(qx_buf);
++//
++//     free(x);
++//     free(qx);
++//     free(qx_res);
++// }
++
++static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
++    VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
++    const size_t x_ne = m * k * batch;
++    const size_t y_ne = k * n * batch;
++    const size_t d_ne = m * n * batch;
++
++    vk_matmul_pipeline2 * pipelines;
++
++    if (mmq) {
++        pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
++    } else {
++        pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
++    }
++
++    const bool fp16acc = ctx->device->fp16;
++
++    vk_pipeline p;
++    std::string shname;
++    if (shader_size == 0) {
++        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
++        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
++    } else if (shader_size == 1) {
++        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
++        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
++    } else if (shader_size == 2) {
++        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
++        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
++    } else {
++        GGML_ASSERT(0);
++    }
++
++    const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
++
++    if (mmq || k != kpad) {
++        if (shader_size == 0) {
++            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
++            shname = std::string(ggml_type_name(quant)) + "_S";
++        } else if (shader_size == 1) {
++            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
++            shname = std::string(ggml_type_name(quant)) + "_M";
++        } else if (shader_size == 2) {
++            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
++            shname = std::string(ggml_type_name(quant)) + "_L";
++        } else {
++            GGML_ASSERT(0);
++        }
++    }
++
++    if (p == nullptr) {
++        std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
++        return;
++    }
++
++    const size_t x_sz = sizeof(float) * x_ne;
++    const size_t y_sz = sizeof(float) * y_ne;
++    const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
++    const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
++    const size_t d_sz = sizeof(float) * d_ne;
++    float * x = (float *) malloc(x_sz);
++    float * y = (float *) malloc(y_sz);
++    void * qx = malloc(qx_sz);
++    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++    float * d = (float *) malloc(d_sz);
++    float * d_chk = (float *) malloc(d_sz);
++
++    for (size_t i = 0; i < x_ne; i++) {
++        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
++        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
++        // x[i] = i % k;
++    }
++
++    ggml_vk_quantize_data(x, qx, x_ne, quant);
++
++    for (size_t i = 0; i < y_ne; i++) {
++        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
++        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
++        // y[i] = i % k;
++    }
++
++    if (split_k > 1) {
++        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
++
++        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
++            // Resize buffer
++            if (ctx->prealloc_split_k != nullptr) {
++                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
++            }
++            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
++        }
++    }
++    if (mmq) {
++        vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
++        ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
++    }
++
++    ggml_pipeline_allocate_descriptor_sets(ctx);
++
++    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
++    ggml_vk_buffer_write(y_buf, 0, y, y_sz);
++
++    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++    ggml_vk_ctx_begin(ctx->device, subctx);
++    if (mmq) {
++        for (size_t i = 0; i < num_it; i++) {
++            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
++            ggml_vk_matmul(
++                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
++                m, n, k,
++                k, k, m, k*m, k*n, m*n,
++                split_k, batch, batch, batch, 1, 1, n
++            );
++        }
++    } else {
++        for (size_t i = 0; i < num_it; i++) {
++            ggml_vk_matmul(
++                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
++                m, n, k,
++                k, k, m, k*m, k*n, m*n,
++                split_k, batch, batch, batch, 1, 1, n
++            );
++        }
++    }
++    ggml_vk_ctx_end(subctx);
++
++    auto begin = std::chrono::high_resolution_clock::now();
++
++    ggml_vk_submit(subctx, ctx->fence);
++    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
++    ctx->device->device.resetFences({ ctx->fence });
++    ggml_vk_queue_command_pools_cleanup(ctx->device);
++
++    auto end = std::chrono::high_resolution_clock::now();
++
++    double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
++    ggml_vk_buffer_read(d_buf, 0, d, d_sz);
++
++    ggml_init_params iparams = {
++        /*.mem_size   =*/ 1024*1024*1024,
++        /*.mem_buffer =*/ NULL,
++        /*.no_alloc   =*/ true,
++    };
++
++    ggml_context * ggml_ctx = ggml_init(iparams);
++
++    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
++    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
++    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
++
++    src0_ggml->data = qx;
++    src1_ggml->data = y;
++    tensor_ggml->data = d_chk;
++
++    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
++    ggml_build_forward_expand(cgraph, tensor_ggml);
++
++    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
++
++    ggml_free(ggml_ctx);
++
++    double avg_err = 0.0;
++    int first_err_n = -1;
++    int first_err_m = -1;
++    int first_err_b = -1;
++
++    for (size_t i = 0; i < m*n*batch; i++) {
++        double err = std::fabs(d[i] - d_chk[i]);
++        avg_err += err;
++
++        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
++            first_err_b = i / (m * n);
++            first_err_n = (i % (m * n)) / m;
++            first_err_m = (i % (m * n)) % m;
++        }
++    }
++
++    avg_err /= m * n;
++
++    double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
++
++    std::cerr << "TEST dequant matmul " << shname;
++    if (mmq) {
++        std::cerr << " mmq";
++    }
++    std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
++
++    if (avg_err > 0.01 || std::isnan(avg_err)) {
++        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
++        std::cerr << "Actual result: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++        std::cerr << std::endl;
++        std::cerr << "Expected result: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++        std::cerr << "src0: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
++        std::cerr << std::endl;
++        std::cerr << "src1: " << std::endl << std::endl;
++        ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
++
++        if (split_k > 1) {
++            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
++            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
++
++            std::cerr << "d_buf0: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf1: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf2: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            std::cerr << "d_buf3: " << std::endl << std::endl;
++            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
++
++            free(split_k_buf);
++        }
++    }
++
++    ggml_vk_destroy_buffer(qx_buf);
++    ggml_vk_destroy_buffer(y_buf);
++    ggml_vk_destroy_buffer(qy_buf);
++    ggml_vk_destroy_buffer(d_buf);
++
++    free(x);
++    free(qx);
++    free(y);
++    free(d);
++    free(d_chk);
++}
++#endif
++
++static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {
++#if defined(GGML_VULKAN_RUN_TESTS)
++    const std::vector<size_t> vals {
++        512, 512, 128,
++        128, 512, 512,
++        4096, 512, 4096,
++        11008, 512, 4096,
++        4096, 512, 11008,
++        32000, 512, 4096,
++        8, 8, 8,
++        100, 46, 576,
++        623, 111, 128,
++        100, 46, 558,
++        512, 1, 256,
++        128, 110, 622,
++        511, 511, 127,
++        511, 511, 7,
++        511, 511, 17,
++        49, 49, 128,
++        128, 49, 49,
++        4096, 49, 4096,
++    };
++    const size_t num_it = 100;
++
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
++
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
++
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
++
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
++    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
++
++    abort();
++
++    for (size_t i = 0; i < vals.size(); i += 3) {
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
++        std::cerr << '\n';
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
++        std::cerr << '\n';
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
++        ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
++        std::cerr << '\n' << std::endl;
++
++        if (vals[i + 2] % 32 == 0) {
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
++            std::cerr << '\n';
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
++            std::cerr << '\n';
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
++            std::cerr << '\n' << std::endl;
++        }
++
++        if (vals[i + 2] % 256 == 0) {
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
++            std::cerr << '\n';
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
++            std::cerr << '\n';
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
++            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
++            std::cerr << '\n' << std::endl;
++        }
++    }
++
++    GGML_ABORT("fatal error");
++#endif
++
++    if (subctx) {
++        // Submit and wait for any pending work before reallocating the buffers
++        ggml_vk_ctx_end(subctx);
++        ggml_vk_submit(subctx, {});
++        ctx->submit_pending = true;
++        ggml_vk_synchronize(ctx);
++        GGML_ASSERT(ctx->compute_ctx.expired());
++        ggml_vk_ctx_begin(ctx->device, subctx);
++        ctx->compute_ctx = subctx;
++    }
++
++    if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
++        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
++        // Resize buffer
++        if (ctx->prealloc_x != nullptr) {
++            ggml_vk_destroy_buffer(ctx->prealloc_x);
++        }
++        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
++    }
++    if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
++        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
++        // Resize buffer
++        if (ctx->prealloc_y != nullptr) {
++            ggml_vk_destroy_buffer(ctx->prealloc_y);
++        }
++        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
++        ctx->prealloc_y_last_tensor_used = nullptr;
++    }
++    if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
++        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
++        // Resize buffer
++        if (ctx->prealloc_split_k != nullptr) {
++            ggml_vk_destroy_buffer(ctx->prealloc_split_k);
++        }
++        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
++    }
++    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
++        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")");
++        // Resize buffer
++        if (ctx->prealloc_add_rms_partials != nullptr) {
++            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
++        }
++        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
++    }
++}
++
++static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
++
++// Returns true if node has enqueued work into the queue, false otherwise
++// Build the Vulkan command list for cgraph->nodes[node_idx] (plus any additional fused ops).
++// If submit is true, all operations queued so far are submitted to Vulkan so that
++// command-list creation overlaps GPU execution.
++// Returns false when the node is empty / non-compute or its op is unsupported; true otherwise.
++static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
++    ggml_tensor * node = cgraph->nodes[node_idx];
++    if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
++        return false;
++    }
++    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
++        return false;
++    }
++
++    VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
++    ctx->semaphore_idx = 0;
++
++    ggml_tensor * src0 = node->src[0];
++    ggml_tensor * src1 = node->src[1];
++    ggml_tensor * src2 = node->src[2];
++    ggml_tensor * src3 = node->src[3];
++
++    // Detect an ADD immediately followed by a single-row RMS_NORM consuming it, and
++    // reserve space in the add_rms partials buffer so the two can be fused.
++    if (node->op == GGML_OP_ADD) {
++        int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
++        if (next_node_idx < cgraph->n_nodes &&
++            cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
++            cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
++            ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
++            ctx->device->add_rms_fusion) {
++            uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
++            ctx->do_add_rms_partials_offset_calculation = true;
++            if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
++                ctx->do_add_rms_partials = true;
++            }
++        }
++    }
++
++    vk_context compute_ctx;
++
++    if (ctx->compute_ctx.expired()) {
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++    } else {
++        compute_ctx = ctx->compute_ctx.lock();
++    }
++
++    {
++        // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
++        // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
++        // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
++        // outside of this logic. When a node uses one of the prealloc buffers for something like
++        // dequantization or split_k, additional synchronization is needed between those passes.
++        bool need_sync = false;
++
++        // Check whether "node" requires synchronization. The node requires synchronization if it
++        // overlaps in memory with another unsynchronized node and at least one of them is a write.
++        // Destination nodes are checked against both the written/read lists. Source nodes are only
++        // checked against the written list. Two nodes overlap in memory if they come from the same
++        // buffer and the tensor or view ranges overlap.
++        auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
++            if (unsynced_nodes.size() == 0) {
++                return false;
++            }
++            auto n_base = vk_tensor_offset(node) + node->view_offs;
++            auto n_size = ggml_nbytes(node);
++            ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
++            vk_buffer a_buf = a_buf_ctx->dev_buffer;
++            for (auto &other : unsynced_nodes) {
++                ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
++                vk_buffer o_buf = o_buf_ctx->dev_buffer;
++                if (a_buf == o_buf) {
++                    auto o_base = vk_tensor_offset(other) + other->view_offs;
++                    auto o_size = ggml_nbytes(other);
++
++                    if ((o_base <= n_base && n_base < o_base + o_size) ||
++                        (n_base <= o_base && o_base < n_base + n_size)) {
++                        return true;
++                    }
++                }
++            }
++            return false;
++        };
++
++        // For all fused ops, check if the destination node or any of the source
++        // nodes require synchronization.
++        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
++            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
++            // If the node actually writes to memory, then check if it needs to sync
++            if (ctx->fused_ops_write_mask & (1 << i)) {
++                if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
++                    need_sync = true;
++                    break;
++                }
++            }
++            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
++                if (!cur_node->src[j]) {
++                    continue;
++                }
++                if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
++                    need_sync = true;
++                    break;
++                }
++            }
++        }
++
++        if (need_sync) {
++            if (vk_enable_sync_logger) {
++                std::cerr <<  "sync" << std::endl;
++            }
++            ctx->unsynced_nodes_written.clear();
++            ctx->unsynced_nodes_read.clear();
++            ggml_vk_sync_buffers(ctx, compute_ctx);
++
++            if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
++                ctx->query_node_idx[ctx->query_idx] = node_idx;
++                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
++            }
++        }
++        // Add all fused nodes to the unsynchronized lists.
++        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
++            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
++            // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
++            if (ctx->fused_ops_write_mask & (1 << i)) {
++                ctx->unsynced_nodes_written.push_back(cur_node);
++            }
++            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
++                if (!cur_node->src[j]) {
++                    continue;
++                }
++                ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
++            }
++        }
++    }
++    if (vk_enable_sync_logger) {
++        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
++            auto *n = cgraph->nodes[node_idx + i];
++            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " <<  n->name;
++            if (n->op == GGML_OP_GLU) {
++                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
++            }
++            if (n->op == GGML_OP_ROPE) {
++                const int mode = ((const int32_t *) n->op_params)[2];
++                std::cerr << " rope mode: " << mode;
++            }
++            std::cerr << std::endl;
++        }
++    }
++
++    // Dispatch to the per-op record function. Unsupported ops return false.
++    switch (node->op) {
++    case GGML_OP_REPEAT:
++        ggml_vk_repeat(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_REPEAT_BACK:
++        ggml_vk_repeat_back(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_ACC:
++    case GGML_OP_SET:
++        ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_GET_ROWS:
++        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_ADD:
++        if (ctx->num_additional_fused_ops) {
++            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);
++        } else {
++            ggml_vk_add(ctx, compute_ctx, src0, src1, node);
++        }
++        break;
++    case GGML_OP_SUB:
++        ggml_vk_sub(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_MUL:
++        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_DIV:
++        ggml_vk_div(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_ADD_ID:
++        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);
++
++        break;
++    case GGML_OP_CONCAT:
++        ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_UPSCALE:
++        ggml_vk_upscale(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_ADD1:
++        ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_ARANGE:
++        ggml_vk_arange(ctx, compute_ctx, node);
++
++        break;
++    case GGML_OP_FILL:
++        ggml_vk_fill(ctx, compute_ctx, node);
++
++        break;
++    case GGML_OP_SCALE:
++        ggml_vk_scale(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SQR:
++        ggml_vk_sqr(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SQRT:
++        ggml_vk_sqrt(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SIN:
++        ggml_vk_sin(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_COS:
++        ggml_vk_cos(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_LOG:
++        ggml_vk_log(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_TRI:
++        ggml_vk_tri(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_DIAG:
++        ggml_vk_diag(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_CLAMP:
++        ggml_vk_clamp(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_PAD:
++        ggml_vk_pad(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_ROLL:
++        ggml_vk_roll(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_CPY:
++    case GGML_OP_CONT:
++    case GGML_OP_DUP:
++        ggml_vk_cpy(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SET_ROWS:
++        ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_SILU_BACK:
++        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_NORM:
++        ggml_vk_norm(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_GROUP_NORM:
++        ggml_vk_group_norm(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_RMS_NORM:
++        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
++        break;
++    case GGML_OP_RMS_NORM_BACK:
++        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_L2_NORM:
++        ggml_vk_l2_norm(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_UNARY:
++        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
++            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
++            break;
++        }
++
++        switch (ggml_get_unary_op(node)) {
++        case GGML_UNARY_OP_EXP:
++        case GGML_UNARY_OP_SILU:
++        case GGML_UNARY_OP_GELU:
++        case GGML_UNARY_OP_GELU_ERF:
++        case GGML_UNARY_OP_GELU_QUICK:
++        case GGML_UNARY_OP_RELU:
++        case GGML_UNARY_OP_NEG:
++        case GGML_UNARY_OP_TANH:
++        case GGML_UNARY_OP_SIGMOID:
++        case GGML_UNARY_OP_HARDSIGMOID:
++        case GGML_UNARY_OP_HARDSWISH:
++        case GGML_UNARY_OP_ABS:
++        case GGML_UNARY_OP_SOFTPLUS:
++        case GGML_UNARY_OP_STEP:
++        case GGML_UNARY_OP_ROUND:
++        case GGML_UNARY_OP_CEIL:
++        case GGML_UNARY_OP_FLOOR:
++        case GGML_UNARY_OP_TRUNC:
++            ggml_vk_unary(ctx, compute_ctx, src0, node);
++            break;
++        case GGML_UNARY_OP_XIELU:
++            ggml_vk_xielu(ctx, compute_ctx, src0, node);
++            break;
++        default:
++            return false;
++        }
++        break;
++    case GGML_OP_GLU:
++        switch (ggml_get_glu_op(node)) {
++        case GGML_GLU_OP_GEGLU:
++        case GGML_GLU_OP_REGLU:
++        case GGML_GLU_OP_SWIGLU:
++        case GGML_GLU_OP_SWIGLU_OAI:
++        case GGML_GLU_OP_GEGLU_ERF:
++        case GGML_GLU_OP_GEGLU_QUICK:
++            ggml_vk_glu(ctx, compute_ctx, src0, src1, node);
++            break;
++        default:
++            return false;
++        }
++        break;
++    case GGML_OP_DIAG_MASK_INF:
++        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SOFT_MAX:
++        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
++            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
++        } else {
++            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
++        }
++
++        break;
++    case GGML_OP_SOFT_MAX_BACK:
++        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_ROPE:
++        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false);
++
++        break;
++    case GGML_OP_ROPE_BACK:
++        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
++
++        break;
++    case GGML_OP_ARGSORT:
++        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
++            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
++        } else {
++            ggml_vk_argsort(ctx, compute_ctx, src0, node);
++        }
++
++        break;
++    case GGML_OP_TOP_K:
++        ggml_vk_topk(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SUM:
++        ggml_vk_sum(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_SUM_ROWS:
++        ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_CUMSUM:
++        ggml_vk_cumsum(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_MEAN:
++        ggml_vk_mean(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_ARGMAX:
++        ggml_vk_argmax(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_COUNT_EQUAL:
++        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_SOLVE_TRI:
++        ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_IM2COL:
++        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_IM2COL_3D:
++        ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_TIMESTEP_EMBEDDING:
++        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_CONV_TRANSPOSE_1D:
++        ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_POOL_2D:
++        ggml_vk_pool_2d(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_CONV_2D:
++    case GGML_OP_CONV_TRANSPOSE_2D:
++        ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_CONV_2D_DW:
++        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
++
++        break;
++    case GGML_OP_LEAKY_RELU:
++        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
++
++        break;
++    case GGML_OP_MUL_MAT:
++        ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);
++
++        break;
++    case GGML_OP_MUL_MAT_ID:
++        ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);
++
++        break;
++
++    case GGML_OP_FLASH_ATTN_EXT:
++        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);
++
++        break;
++
++    case GGML_OP_RWKV_WKV6:
++        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);
++
++        break;
++
++    case GGML_OP_RWKV_WKV7:
++        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);
++
++        break;
++
++    case GGML_OP_SSM_SCAN:
++        ggml_vk_ssm_scan(ctx, compute_ctx, node);
++
++        break;
++
++    case GGML_OP_SSM_CONV:
++        ggml_vk_ssm_conv(ctx, compute_ctx, node);
++
++        break;
++
++    case GGML_OP_OPT_STEP_ADAMW:
++        ggml_vk_opt_step_adamw(ctx, compute_ctx, node);
++
++        break;
++
++    case GGML_OP_OPT_STEP_SGD:
++        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);
++
++        break;
++    default:
++        return false;
++    }
++
++    ctx->tensor_ctxs[node_idx] = compute_ctx;
++
++#if defined(GGML_VULKAN_CHECK_RESULTS)
++    // Force context reset on each node so that each tensor ends up in its own context
++    // and can be run and compared to its CPU equivalent separately
++    last_node = true;
++#endif
++
++    if (submit || last_node) {
++        ggml_vk_ctx_end(compute_ctx);
++
++        // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward
++        if (last_node) {
++            compute_ctx->exit_tensor_idx = node_idx_begin;
++        }
++        else {
++            compute_ctx->exit_tensor_idx = -1;
++        }
++
++        ctx->compute_ctx.reset();
++
++        ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
++    }
++    return true;
++}
++
++// Submit the queued command context associated with tensor_idx to the GPU and perform
++// the deferred host-side staging copies. When tensor_idx equals the context's
++// exit_tensor_idx, output staging copies are flushed and the memcpy/memset lists reset.
++static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
++    GGML_UNUSED(cgraph);
++    GGML_UNUSED(tensor);
++
++    VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
++
++    vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
++
++    // Only run if ctx hasn't been submitted yet
++    if (!subctx->seqs.empty()) {
++#ifdef GGML_VULKAN_CHECK_RESULTS
++        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
++#endif
++
++        // Do staging buffer copies
++        for (auto& cpy : subctx->in_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++
++        for (auto& mset : subctx->memsets) {
++            memset(mset.dst, mset.val, mset.n);
++        }
++
++        // Signal the almost_ready fence at most once per graph so the scheduler can
++        // overlap the tail of this graph with the next one.
++        if (almost_ready && !ctx->almost_ready_fence_pending) {
++            ggml_vk_submit(subctx, ctx->almost_ready_fence);
++            ctx->almost_ready_fence_pending = true;
++        } else {
++            ggml_vk_submit(subctx, {});
++        }
++        ctx->submit_pending = true;
++
++#ifdef GGML_VULKAN_CHECK_RESULTS
++        ggml_vk_synchronize(ctx);
++        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
++#endif
++    }
++
++    if (tensor_idx == subctx->exit_tensor_idx) {
++        // Do staging buffer copies
++        for (auto& cpy : subctx->out_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++        subctx->in_memcpys.clear();
++        subctx->out_memcpys.clear();
++        subctx->memsets.clear();
++    }
++}
++
++// Reset per-graph state after a graph has finished: sync bookkeeping, semaphores,
++// events, per-tensor contexts, and descriptor-set accounting.
++static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
++    VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
++    ctx->prealloc_y_last_pipeline_used = {};
++
++    // Drop dependency-tracking state from the previous graph.
++    ctx->unsynced_nodes_written.clear();
++    ctx->unsynced_nodes_read.clear();
++    ctx->prealloc_x_need_sync = false;
++    ctx->prealloc_y_need_sync = false;
++    ctx->prealloc_split_k_need_sync = false;
++
++    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
++
++    // Per-graph semaphores are destroyed and recreated on demand.
++    for (auto & sem : ctx->gc.semaphores) {
++        ctx->device->device.destroySemaphore({ sem.s });
++    }
++    ctx->gc.semaphores.clear();
++
++    for (auto & sem : ctx->gc.tl_semaphores) {
++        ctx->device->device.destroySemaphore({ sem.s });
++    }
++    ctx->gc.tl_semaphores.clear();
++    ctx->semaphore_idx = 0;
++
++    ctx->event_idx = 0;
++
++    // Events are reused across graphs; reset rather than destroy them.
++    for (auto & event : ctx->gc.events) {
++        ctx->device->device.resetEvent(event);
++    }
++
++    ctx->tensor_ctxs.clear();
++    ctx->gc.contexts.clear();
++    ctx->pipeline_descriptor_set_requirements = 0;
++    ctx->descriptor_set_idx = 0;
++}
++
++// Clean up on backend free. Destruction order matters: pending work is discarded and
++// waited on first, then per-graph state, then buffers/fences/pools.
++static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
++    VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
++    // discard any unsubmitted command buffers
++    ctx->compute_ctx.reset();
++    // wait for any pending command buffers to finish
++    ggml_vk_synchronize(ctx);
++
++    ggml_vk_graph_cleanup(ctx);
++
++    // Release all preallocated scratch buffers and the staging buffer.
++    ggml_vk_destroy_buffer(ctx->prealloc_x);
++    ggml_vk_destroy_buffer(ctx->prealloc_y);
++    ggml_vk_destroy_buffer(ctx->prealloc_split_k);
++    ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
++    ggml_vk_destroy_buffer(ctx->sync_staging);
++
++    ctx->prealloc_y_last_pipeline_used = nullptr;
++
++    ctx->prealloc_size_x = 0;
++    ctx->prealloc_size_y = 0;
++    ctx->prealloc_size_split_k = 0;
++
++    // Unlike graph cleanup (which only resets events), backend free destroys them.
++    for (auto& event : ctx->gc.events) {
++        ctx->device->device.destroyEvent(event);
++    }
++    ctx->gc.events.clear();
++
++    ctx->device->device.destroyFence(ctx->fence);
++    ctx->device->device.destroyFence(ctx->almost_ready_fence);
++
++    for (auto& pool : ctx->descriptor_pools) {
++        ctx->device->device.destroyDescriptorPool(pool);
++    }
++    ctx->descriptor_pools.clear();
++    ctx->descriptor_sets.clear();
++
++    ctx->compute_cmd_pool.destroy(ctx->device->device);
++    if (vk_perf_logger_enabled) {
++        ctx->perf_logger->print_timings(true);
++    }
++}
++
++// Return the number of usable Vulkan devices, initializing the instance on first use.
++static int ggml_vk_get_device_count() {
++    ggml_vk_instance_init();
++
++    // device_indices.size() is a size_t; convert explicitly to the int interface type.
++    return static_cast<int>(vk_instance.device_indices.size());
++}
++
++// Write the human-readable name of physical device `device` into `description`
++// (truncated to description_size, always NUL-terminated by snprintf).
++static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
++    ggml_vk_instance_init();
++
++    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
++
++    vk::PhysicalDeviceProperties props;
++    devices[device].getProperties(&props);
++
++    snprintf(description, description_size, "%s", props.deviceName.data());
++}
++
++// backend interface
++
++#define UNUSED GGML_UNUSED
++
++// device backend
++
++// True when `buffer` was allocated by the Vulkan buffer type (identified by its
++// get_name function pointer, the conventional ggml-backend type check).
++static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
++    return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
++}
++
++// Free a Vulkan-backed buffer: release the device allocation, then the wrapper
++// context and the ggml buffer object itself.
++static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
++    VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
++    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
++    ggml_vk_destroy_buffer(ctx->dev_buffer);
++    delete ctx;
++    delete buffer;
++}
++
++// Device memory is not host-addressable; return a sentinel base pointer so tensor
++// "addresses" encode offsets (recovered later via vk_tensor_offset).
++static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
++    return vk_ptr_base;
++
++    UNUSED(buffer);
++}
++
++// Tensor initialization hook: no per-tensor setup is needed, but views must live in
++// the same buffer type as their source.
++static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
++    VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
++    if (tensor->view_src != nullptr) {
++        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
++    }
++    return GGML_STATUS_SUCCESS;
++}
++
++// Fill `size` bytes of `tensor` (starting at `offset`) with the byte `value`.
++static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
++    VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
++    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
++
++    // The Vulkan fill primitive works on 32-bit words; replicate the byte into all four lanes.
++    const uint32_t val32 = (uint32_t)value * 0x01010101;
++    const size_t dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
++    ggml_vk_buffer_memset(buf_ctx->dev_buffer, dst_offset, val32, size);
++}
++
++// Synchronously upload `size` bytes from host `data` into `tensor` at `offset`.
++static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
++    VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
++    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
++    const size_t dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
++
++    ggml_vk_buffer_write(buf_ctx->dev_buffer, dst_offset, data, size);
++}
++
++// Synchronously download `size` bytes of `tensor` at `offset` into host `data`.
++static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
++    VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
++    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
++    const size_t src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
++
++    ggml_vk_buffer_read(buf_ctx->dev_buffer, src_offset, data, size);
++}
++
++// Device-to-device tensor copy. Only handled when the source also lives in a Vulkan
++// buffer; returns false otherwise so the caller falls back to a host round-trip.
++static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
++    if (!ggml_backend_buffer_is_vk(src->buffer)) {
++        return false;
++    }
++
++    ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
++    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++
++    const size_t src_offset = vk_tensor_offset(src) + src->view_offs;
++    const size_t dst_offset = vk_tensor_offset(dst) + dst->view_offs;
++
++    ggml_vk_buffer_copy(dst_buf_ctx->dev_buffer, dst_offset, src_buf_ctx->dev_buffer, src_offset, ggml_nbytes(src));
++
++    return true;
++
++    UNUSED(buffer);
++}
++
++// Fill the whole device buffer with `value`.
++// NOTE(review): the memset value argument here is the raw byte, unlike
++// memset_tensor above which replicates it into a 32-bit word — confirm
++// ggml_vk_buffer_memset's expected value encoding.
++static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
++    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
++
++    ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
++}
++
++// ggml-backend buffer vtable for Vulkan device buffers.
++static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
++    /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
++    /* .get_base        = */ ggml_backend_vk_buffer_get_base,
++    /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
++    /* .memset_tensor   = */ ggml_backend_vk_buffer_memset_tensor,
++    /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
++    /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
++    /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
++    /* .clear           = */ ggml_backend_vk_buffer_clear,
++    /* .reset           = */ NULL,
++};
++
++// vk buffer type
++// vk buffer type
++// Name of the device buffer type (also serves as the identity marker used by
++// ggml_backend_buffer_is_vk).
++static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
++    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
++
++    return ctx->name.c_str();
++}
++
++// Allocate a device buffer of `size` bytes. Returns nullptr on Vulkan allocation
++// failure so the scheduler can try a smaller allocation or another backend.
++static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
++    VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
++    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
++
++    vk_buffer dev_buffer = nullptr;
++    try {
++        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
++    } catch (const vk::SystemError& e) {
++        // Out of device memory (or similar); signal failure instead of crashing.
++        return nullptr;
++    }
++
++    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
++
++    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
++}
++
++// Tensor alignment requirement: storage-buffer offsets must obey the device's
++// minStorageBufferOffsetAlignment.
++static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
++    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
++    return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
++}
++
++// Maximum single-buffer size; larger requests are split into suballocation blocks.
++static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
++    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
++    return ctx->device->suballocation_block_size;
++}
++
++// Allocation size for a tensor: exactly its byte size, no per-tensor padding.
++static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
++    return ggml_nbytes(tensor);
++
++    UNUSED(buft);
++}
++
++// Public entry point: return the device-buffer type for device `dev_num`,
++// initializing the Vulkan instance and device on first use.
++ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
++    ggml_vk_instance_init();
++
++    VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
++
++    vk_device dev = ggml_vk_get_device(dev_num);
++
++    return &dev->buffer_type;
++}
++
++// host buffer type
++
++// Name of the pinned-host buffer type.
++static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
++    return GGML_VK_NAME "_Host";
++
++    UNUSED(buft);
++}
++
++// Name of an individual pinned-host buffer.
++// NOTE(review): not referenced in this chunk — possibly dead; confirm before removing.
++static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
++    return GGML_VK_NAME "_Host";
++
++    UNUSED(buffer);
++}
++
++// Free a pinned-host buffer. Pinned memory is always owned by device 0 (see
++// the matching alloc below).
++static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
++    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
++    ggml_vk_host_free(vk_instance.devices[0], buffer->context);
++    delete buffer;
++}
++
++// Allocate pinned (page-locked) host memory for fast transfers; on failure fall
++// back to a plain CPU buffer so allocation never hard-fails here.
++static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
++    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
++
++    size += 32;  // Behave like the CPU buffer type
++    void * ptr = nullptr;
++    try {
++        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
++    } catch (vk::SystemError& e) {
++        GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
++        // fallback to cpu buffer
++        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
++    }
++
++    // Wrap the pinned pointer as a CPU buffer, then rebrand it as this buffer type
++    // and install the Vulkan-aware free function.
++    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
++    buffer->buft = buft;
++    buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
++
++    return buffer;
++
++    UNUSED(buft);
++}
++
++// Host buffer alignment: the device's minimum memory-map alignment.
++static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
++    return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
++
++    UNUSED(buft);
++}
++
++// Maximum pinned-host allocation size, mirroring the device suballocation limit.
++static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
++    return vk_instance.devices[0]->suballocation_block_size;
++
++    UNUSED(buft);
++}
++
++// Should be changed to return device-specific host buffer type
++// but that probably requires changes in llama.cpp
++ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
++    static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
++        /* .iface    = */ {
++            /* .get_name         = */ ggml_backend_vk_host_buffer_type_name,
++            /* .alloc_buffer     = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
++            /* .get_alignment    = */ ggml_backend_vk_host_buffer_type_get_alignment,
++            /* .get_max_size     = */ ggml_backend_vk_host_buffer_type_get_max_size,
++            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
++            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
++        },
++        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
++        /* .context  = */ nullptr,
++    };
++
++    // Make sure device 0 is initialized
++    ggml_vk_instance_init();
++    ggml_vk_get_device(0);
++
++    return &ggml_backend_vk_buffer_type_host;
++}
++
++
++// backend
++
++// backend name, e.g. "Vulkan0".
++static const char * ggml_backend_vk_name(ggml_backend_t backend) {
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++
++    return ctx->name.c_str();
++}
++
++// Destroy the backend: flush/release all Vulkan resources, then free the context
++// and the backend object.
++static void ggml_backend_vk_free(ggml_backend_t backend) {
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
++
++    ggml_vk_cleanup(ctx);
++
++    delete ctx;
++    delete backend;
++}
++
++// Default buffer type for this backend instance: its device's buffer type.
++static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++
++    return &ctx->device->buffer_type;
++}
++
++// Asynchronously upload host data into a tensor, recording the copy into the current
++// compute context. If the direct async write path fails, fall back to a synchronous
++// copy through the shared staging buffer.
++static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
++    VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
++
++    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
++
++    vk_context compute_ctx;
++
++    if (ctx->compute_ctx.expired()) {
++        // Initialize new transfer context
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++    } else {
++        compute_ctx = ctx->compute_ctx.lock();
++    }
++
++    vk_buffer buf = buf_ctx->dev_buffer;
++
++    auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
++
++    bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);
++
++    // Fallback: copy through the staging buffer and synchronize. The host memcpy into
++    // staging is deferred (in_memcpys) and executed just before submission.
++    if (!ret) {
++        ggml_vk_ensure_sync_staging_buffer(ctx, size);
++        ggml_vk_sync_buffers(nullptr, compute_ctx);
++
++        vk::BufferCopy buffer_cpy;
++        buffer_cpy.srcOffset = 0;
++        buffer_cpy.dstOffset = dst_offset;
++        buffer_cpy.size = size;
++
++        compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
++        deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
++        ggml_vk_synchronize(ctx);
++    }
++}
++
++// Asynchronously download `size` bytes from `tensor`'s device buffer
++// (starting `offset` bytes into the tensor's data) into host memory `data`.
++// Mirrors ggml_backend_vk_set_tensor_async: tries the async read path first,
++// then falls back to a synchronous copy through the staging buffer.
++static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
++    VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
++
++    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
++
++    vk_context compute_ctx;
++
++    // Reuse the open compute context if one exists; otherwise begin a new one.
++    if (ctx->compute_ctx.expired()) {
++        // Initialize new transfer context
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++    } else {
++        compute_ctx = ctx->compute_ctx.lock();
++    }
++
++    vk_buffer buf = buf_ctx->dev_buffer;
++
++    // Source byte offset within the device buffer.
++    auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
++    bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
++
++    // If that failed, copy synchronously through a staging buffer
++    if (!ret) {
++        ggml_vk_ensure_sync_staging_buffer(ctx, size);
++        ggml_vk_sync_buffers(nullptr, compute_ctx);
++
++        vk::BufferCopy buffer_cpy;
++        buffer_cpy.srcOffset = src_offset;
++        buffer_cpy.dstOffset = 0;
++        buffer_cpy.size = size;
++
++        compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
++        // The staging->host memcpy is deferred until after the fence wait
++        // inside ggml_vk_synchronize (out_memcpys).
++        deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
++        ggml_vk_synchronize(ctx);
++    }
++}
++
++// Record an asynchronous device-to-device copy of `src` into `dst`.
++// Returns true if the copy was recorded; returns false (so the caller can
++// fall back to a generic path) when `dst` is not in this backend's default or
++// host buffer type, or `src` is not backed by a Vulkan buffer.
++static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
++    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
++        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
++        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
++
++        vk_context compute_ctx;
++
++        // Reuse the open compute context if one exists; otherwise begin a new one.
++        if (ctx->compute_ctx.expired()) {
++            // Initialize new transfer context
++            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++            ctx->compute_ctx = compute_ctx;
++            ggml_vk_ctx_begin(ctx->device, compute_ctx);
++        } else {
++            compute_ctx = ctx->compute_ctx.lock();
++        }
++
++        vk_buffer src_buf = src_buf_ctx->dev_buffer;
++        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
++
++        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
++        return true;
++    }
++
++    return false;
++}
++
++// Flush the context's open compute context (if any) and block until all
++// previously submitted work completes. Ordering matters here:
++//   1. end the command buffer and perform the deferred host->staging copies
++//      (in_memcpys) before submitting, so uploads see the data;
++//   2. submit, then wait on the context fence if any submit is pending;
++//   3. after the fence signals, perform the deferred staging->host copies
++//      (out_memcpys) and release the weak reference to the compute context.
++static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
++    VK_LOG_DEBUG("ggml_vk_synchronize()");
++
++    bool do_transfer = !ctx->compute_ctx.expired();
++
++    vk_context compute_ctx;
++    if (do_transfer) {
++        compute_ctx = ctx->compute_ctx.lock();
++
++        ggml_vk_ctx_end(compute_ctx);
++
++        // Host-side input copies must happen before the GPU consumes them.
++        for (auto& cpy : compute_ctx->in_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++
++        ggml_vk_submit(compute_ctx, {});
++        ctx->submit_pending = true;
++    }
++
++    if (ctx->submit_pending) {
++        {
++            // Empty submit whose only purpose is to signal the fence after
++            // all prior queue work; queue access is serialized by queue_mutex.
++            std::lock_guard guard(queue_mutex);
++            ctx->device->compute_queue.queue.submit({}, ctx->fence);
++        }
++        ggml_vk_wait_for_fence(ctx);
++        ctx->submit_pending = false;
++    }
++
++    if (do_transfer) {
++        // Safe to read results back only after the fence has signaled.
++        for (auto& cpy : compute_ctx->out_memcpys) {
++            memcpy(cpy.dst, cpy.src, cpy.n);
++        }
++        ctx->compute_ctx.reset();
++    }
++}
++
++// Backend-interface synchronize hook: drain all outstanding GPU work for this
++// backend, then release per-graph resources.
++static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
++    VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
++
++    ggml_backend_vk_context * vk_ctx = (ggml_backend_vk_context *)backend->context;
++    ggml_vk_synchronize(vk_ctx);
++    ggml_vk_graph_cleanup(vk_ctx);
++}
++
++// A node is "empty" for scheduling purposes when it produces no GPU work:
++// it has no elements, or its op is a no-op / pure view transformation.
++static bool ggml_vk_is_empty(ggml_tensor * node) {
++    if (ggml_is_empty(node)) {
++        return true;
++    }
++    switch (node->op) {
++    case GGML_OP_NONE:
++    case GGML_OP_RESHAPE:
++    case GGML_OP_TRANSPOSE:
++    case GGML_OP_VIEW:
++    case GGML_OP_PERMUTE:
++        return true;
++    default:
++        return false;
++    }
++}
++
++// Check whether the op sequence `ops` starting at cgraph->nodes[node_idx] can
++// be fused into a single Vulkan dispatch. Runs the generic ggml_can_fuse()
++// check first, then enforces backend-specific constraints for each supported
++// pattern: RMS_NORM+MUL, MUL_MAT+ADD(+ADD), MUL_MAT_ID+ADD_ID(+MUL) and
++// MUL_MAT_ID+MUL.
++// NOTE(review): restored the missing template argument on the `ops`
++// parameter — `std::initializer_list` without `<ggml_op>` is ill-formed as a
++// function parameter (class template argument deduction is not permitted
++// there), and the body compares `ops.begin()[k]` against ggml_op enumerators.
++static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<ggml_op> ops) {
++    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
++        return false;
++    }
++
++    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
++        // additional constraints specific to this fusion
++        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
++        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
++
++        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
++        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
++        // rms_norm only supports f32
++        if (mul->src[0]->type != GGML_TYPE_F32 ||
++            mul->src[1]->type != GGML_TYPE_F32 ||
++            mul->type != GGML_TYPE_F32) {
++            return false;
++        }
++        // if rms_norm is the B operand, then we don't handle broadcast
++        if (rms_norm == mul->src[1] &&
++            !ggml_are_same_shape(mul->src[0], rms_norm)) {
++            return false;
++        }
++        // rms_norm shader assumes contiguous rows
++        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
++            return false;
++        }
++    }
++    // Shared constraints for fusing a bias ADD onto a mat-vec multiply.
++    auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
++        const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
++
++        // mat-vec only
++        if (ggml_nrows(mul) != 1) {
++            return false;
++        }
++        // shaders assume the types match
++        if (mul->type != bias->type) {
++            return false;
++        }
++        // shaders reuse the D shape for bias
++        if (!ggml_are_same_shape(mul, bias) ||
++            !ggml_are_same_stride(mul, bias)) {
++            return false;
++        }
++        // unaligned bias isn't handled
++        if (get_misalign_bytes(ctx, bias) != 0) {
++            return false;
++        }
++        return true;
++    };
++
++    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
++        // additional constraints specific to this fusion
++        const ggml_tensor *mul = cgraph->nodes[node_idx];
++        const ggml_tensor *add = cgraph->nodes[node_idx + 1];
++
++        if (!mm_add_ok(mul, add)) {
++            return false;
++        }
++        if (ops.size() == 3) {
++            // A second bias ADD may be chained; it must satisfy the same rules.
++            if (ops.begin()[2] != GGML_OP_ADD) {
++                return false;
++            }
++            if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
++                return false;
++            }
++        }
++    }
++
++    // Shared constraints for fusing a per-expert scale MUL onto a mul_mat_id.
++    auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
++        const ggml_tensor *scale = mul->src[1];
++
++        if (mmid != mul->src[0]) {
++            return false;
++        }
++        // mat-vec only
++        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
++            return false;
++        }
++        // shaders assume the types match
++        if (mmid->type != scale->type) {
++            return false;
++        }
++        // shaders assume the bias is contiguous
++        if (!ggml_is_contiguous(scale)) {
++            return false;
++        }
++        // unaligned bias isn't handled
++        if (get_misalign_bytes(ctx, scale) != 0) {
++            return false;
++        }
++        // shader only indexes by expert index
++        if (scale->ne[0] != 1 ||
++            scale->ne[1] != mul->ne[1] ||
++            scale->ne[2] != 1 ||
++            scale->ne[3] != 1) {
++            return false;
++        }
++        return true;
++    };
++
++    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
++        // additional constraints specific to this fusion
++        const ggml_tensor *mul = cgraph->nodes[node_idx];
++        const ggml_tensor *add = cgraph->nodes[node_idx + 1];
++        const ggml_tensor *bias = add->src[1];
++
++        if (mul != add->src[0]) {
++            return false;
++        }
++        // mat-vec only
++        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
++            return false;
++        }
++        // shaders assume the types match
++        if (mul->type != bias->type) {
++            return false;
++        }
++        // shaders assume the bias is contiguous
++        if (!ggml_is_contiguous(bias)) {
++            return false;
++        }
++        // the ID tensor must be the same for mul_mat_id and add_id
++        if (mul->src[2] != add->src[2]) {
++            return false;
++        }
++        // unaligned bias isn't handled
++        if (get_misalign_bytes(ctx, bias) != 0) {
++            return false;
++        }
++
++        if (ops.size() == 3) {
++            if (ops.begin()[2] != GGML_OP_MUL) {
++                return false;
++            }
++            const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
++            return mmid_mul_ok(add, mul);
++        }
++    }
++
++    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
++        // additional constraints specific to this fusion
++        const ggml_tensor *mmid = cgraph->nodes[node_idx];
++        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
++
++        if (!mmid_mul_ok(mmid, mul)) {
++            return false;
++        }
++    }
++
++    return true;
++}
++
++// Check whether the top-k MoE subgraph starting at cgraph->nodes[node_idx]
++// can be fused for the given `mode`. The node-index offsets below are tied to
++// the op patterns checked by the caller (topk_moe_early_softmax_norm etc.) —
++// presumably defined alongside this file; verify against those tables when
++// changing any offset.
++static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
++                                      int node_idx, topk_moe_mode mode) {
++
++    const ggml_tensor * softmax;
++    const ggml_tensor * weights;
++    const ggml_tensor * get_rows;
++    const ggml_tensor * argsort;
++
++    // Locate the key nodes of the pattern; their positions differ per mode.
++    switch (mode) {
++    case TOPK_MOE_EARLY_SOFTMAX_NORM:
++        softmax = cgraph->nodes[node_idx + 0];
++        weights = cgraph->nodes[node_idx + 9];
++        get_rows = cgraph->nodes[node_idx + 4];
++        argsort = cgraph->nodes[node_idx + 2];
++        break;
++    case TOPK_MOE_SIGMOID_NORM_BIAS:
++        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
++        weights = cgraph->nodes[node_idx + 10];
++        get_rows = cgraph->nodes[node_idx + 5];
++        argsort = cgraph->nodes[node_idx + 3];
++        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
++            return false;
++        }
++        // bias is expected to be 1D
++        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
++            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
++            return false;
++        }
++        // sigmoid fusion seems to generate infinities on moltenvk
++        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
++            return false;
++        }
++        break;
++    case TOPK_MOE_EARLY_SOFTMAX:
++        softmax = cgraph->nodes[node_idx + 0];
++        weights = cgraph->nodes[node_idx + 4];
++        get_rows = cgraph->nodes[node_idx + 4];
++        argsort = cgraph->nodes[node_idx + 2];
++        break;
++    case TOPK_MOE_LATE_SOFTMAX:
++        softmax = cgraph->nodes[node_idx + 4];
++        weights = cgraph->nodes[node_idx + 5];
++        get_rows = cgraph->nodes[node_idx + 2];
++        argsort = cgraph->nodes[node_idx + 0];
++        break;
++    default:
++        return false;
++    }
++
++    // get_rows must gather from a reshape of the probability tensor, and
++    // (except for the sigmoid-bias mode) selection must use the same probs.
++    ggml_tensor * probs = get_rows->src[0];
++    if (probs->op != GGML_OP_RESHAPE) {
++        return false;
++    }
++    probs = probs->src[0];
++    ggml_tensor * selection_probs = argsort->src[0];
++
++    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
++        return false;
++    }
++
++    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
++        return false;
++    }
++
++    if (softmax->op == GGML_OP_SOFT_MAX) {
++        const float * op_params = (const float *)softmax->op_params;
++
++        // Only plain softmax (scale 1, no ALiBi bias) is handled by the shader.
++        float scale = op_params[0];
++        float max_bias = op_params[1];
++
++        if (scale != 1.0f || max_bias != 0.0f) {
++            return false;
++        }
++
++        // don't fuse when masks or sinks are present
++        if (softmax->src[1] || softmax->src[2]) {
++            return false;
++        }
++    }
++
++    // Expert count must fit the largest precompiled topk-moe pipeline.
++    const int n_expert = softmax->ne[0];
++    if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
++        return false;
++    }
++
++    // The fused shader requires subgroup arithmetic/shuffle with full
++    // subgroup support, and fusion must not be disabled on the device.
++    if (!ctx->device->subgroup_arithmetic ||
++        !ctx->device->subgroup_shuffle ||
++        !ctx->device->subgroup_require_full_support ||
++        ctx->device->disable_fusion) {
++        return false;
++    }
++
++    return true;
++}
++
++// Check whether a ROPE -> VIEW -> SET_ROWS chain starting at
++// cgraph->nodes[node_idx] can be fused. The view must flatten the rope
++// output's first two dims, and only the dtype/mode combinations implemented
++// in the fused shaders are accepted.
++static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
++                                           int node_idx) {
++    GGML_UNUSED(ctx);
++    const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
++    const ggml_tensor *view = cgraph->nodes[node_idx + 1];
++    const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
++
++    // ne3 not tested
++    if (rope->src[0]->ne[3] != 1) {
++        return false;
++    }
++
++    // Destination must be f32 or f16.
++    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
++        return false;
++    }
++
++    // Row indices must be 64-bit.
++    if (set_rows->src[1]->type != GGML_TYPE_I64) {
++        return false;
++    }
++
++    // The view should flatten two dims of rope into one dim
++    if (!ggml_is_contiguous(view) ||
++        view->ne[0] != rope->ne[0] * rope->ne[1]) {
++        return false;
++    }
++
++    // Only norm/neox/mrope shaders have the fusion code
++    const int mode = ((const int32_t *) rope->op_params)[2];
++    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
++        return false;
++    }
++
++    return true;
++}
++
++// Check whether the tensors overlap in memory.
++// Fusions can potentially overwrite src tensors in ways that are not prevented
++// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
++// with the destination, then it's OK for them to overlap if they are exactly equal.
++static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
++    vk_buffer buf_a = ((ggml_backend_vk_buffer_context *)a->buffer->context)->dev_buffer;
++    vk_buffer buf_b = ((ggml_backend_vk_buffer_context *)b->buffer->context)->dev_buffer;
++
++    // Tensors in different device buffers can never alias.
++    if (buf_a != buf_b) {
++        return false;
++    }
++
++    const auto begin_a = vk_tensor_offset(a) + a->view_offs;
++    const auto begin_b = vk_tensor_offset(b) + b->view_offs;
++    const auto size_a  = ggml_nbytes(a);
++    const auto size_b  = ggml_nbytes(b);
++
++    // Exactly-coincident ranges are tolerated for elementwise application.
++    if (elementwise && begin_a == begin_b && size_a == size_b) {
++        return false;
++    }
++
++    // Half-open interval intersection test.
++    return (begin_b <= begin_a && begin_a < begin_b + size_b) ||
++           (begin_a <= begin_b && begin_b < begin_a + size_a);
++}
++
++// Check whether an RMS_NORM -> MUL -> ROPE chain starting at
++// cgraph->nodes[node_idx] can be fused into one shader dispatch.
++static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
++                                               int node_idx) {
++    GGML_UNUSED(ctx);
++    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
++    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
++    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
++
++    // rope mode is stored in op_params[2].
++    const int mode = ((const int32_t *) rope->op_params)[2];
++
++    // noncontig tensors aren't tested, and don't seem common in practice
++    if (!ggml_is_contiguous(rms) ||
++        !ggml_is_contiguous(mul) ||
++        !ggml_is_contiguous(rope)) {
++        return false;
++    }
++
++    // only norm/neox are handled in the shader
++    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
++        return false;
++    }
++
++    // shared memory size for passing data from mul->rope
++    if (mul->ne[0] > 1024) {
++        return false;
++    }
++
++    // conditions for pipeline creation
++    if (!(ctx->device->float_controls_rte_fp16 &&
++        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
++        return false;
++    }
++
++    return true;
++}
++
++// Count how many consecutive GGML_OP_ADD nodes starting at node_idx can be
++// fused into one multi-add dispatch. Returns the number of fusable adds
++// (>= 2), or 0 when no multi-add fusion applies (not an ADD, device lacks
++// support, constraints fail, or only a single add remains).
++static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
++
++    const ggml_tensor *first_node = cgraph->nodes[node_idx];
++    if (first_node->op != GGML_OP_ADD) {
++        return 0;
++    }
++
++    if (!ctx->device->multi_add) {
++        return 0;
++    }
++
++    // Count the run of consecutive ADD nodes, capped at MAX_FUSED_ADDS.
++    int32_t num_adds = 1;
++    while (node_idx + num_adds < cgraph->n_nodes &&
++           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
++           num_adds < MAX_FUSED_ADDS) {
++        num_adds++;
++    }
++
++    // The shader currently requires same shapes (but different strides are allowed),
++    // everything f32, and no misalignment
++    // (assigning num_adds = i truncates the run at the first offending node
++    // and terminates the loop, since the condition i < num_adds then fails).
++    for (int32_t i = 0; i < num_adds; ++i) {
++        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
++        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
++            !ggml_are_same_shape(first_node, next_node->src[1]) ||
++            next_node->type != GGML_TYPE_F32 ||
++            next_node->src[0]->type != GGML_TYPE_F32 ||
++            next_node->src[1]->type != GGML_TYPE_F32 ||
++            get_misalign_bytes(ctx, next_node) ||
++            get_misalign_bytes(ctx, next_node->src[0]) ||
++            get_misalign_bytes(ctx, next_node->src[1])) {
++            num_adds = i;
++        }
++    }
++
++    // Verify we can fuse these
++    ggml_op adds[MAX_FUSED_ADDS];
++    for (int32_t i = 0; i < num_adds; ++i) {
++        adds[i] = GGML_OP_ADD;
++    }
++
++    // decrease num_adds if they can't all be fused
++    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
++        num_adds--;
++    }
++
++    // a single add is not "fused", so just return zero
++    if (num_adds == 1) {
++        return 0;
++    }
++    return num_adds;
++}
++
++// Return the zero-based index of the lowest set bit of x, or -1 when x == 0.
++static int32_t find_first_set(uint32_t x) {
++    if (x == 0) {
++        return -1;
++    }
++    int32_t idx = 0;
++    for (; (x & 1u) == 0; x >>= 1) {
++        idx++;
++    }
++    return idx;
++}
++
++static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++
++    if (vk_instance.debug_utils_support) {
++        vk::DebugUtilsLabelEXT dul = {};
++        dul.pLabelName = "ggml_backend_vk_graph_compute";
++        dul.color = std::array{1.0f, 1.0f, 1.0f, 1.0f};
++        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul));
++    }
++
++    ctx->prealloc_size_add_rms_partials_offset = 0;
++    ctx->do_add_rms_partials = false;
++    ctx->do_add_rms_partials_offset_calculation = false;
++
++    int last_node = cgraph->n_nodes - 1;
++
++    // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
++    while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
++        last_node -= 1;
++    }
++
++    // Reserve tensor context space for all nodes
++    ctx->tensor_ctxs.resize(cgraph->n_nodes);
++
++    bool first_node_in_batch = true; // true if next node will be first node in a batch
++    int submit_node_idx = 0; // index to first node in a batch
++
++    vk_context compute_ctx;
++    if (vk_perf_logger_enabled) {
++        // allocate/resize the query pool
++        if (ctx->num_queries < cgraph->n_nodes + 1) {
++            if (ctx->query_pool) {
++                ctx->device->device.destroyQueryPool(ctx->query_pool);
++            }
++            vk::QueryPoolCreateInfo query_create_info;
++            query_create_info.queryType = vk::QueryType::eTimestamp;
++            query_create_info.queryCount = cgraph->n_nodes + 100;
++            ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
++            ctx->num_queries = query_create_info.queryCount;
++            ctx->query_fusion_names.resize(ctx->num_queries);
++            ctx->query_fusion_node_count.resize(ctx->num_queries);
++            ctx->query_nodes.resize(ctx->num_queries);
++            ctx->query_node_idx.resize(ctx->num_queries);
++        }
++
++        ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
++        std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
++        std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
++        std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
++        std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
++
++        GGML_ASSERT(ctx->compute_ctx.expired());
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++        ctx->query_idx = 0;
++        compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
++    }
++
++    ctx->prealloc_y_last_pipeline_used = nullptr;
++    ctx->prealloc_y_last_tensor_used = nullptr;
++
++    if (ctx->prealloc_size_add_rms_partials) {
++        ggml_vk_preallocate_buffers(ctx, nullptr);
++        if (ctx->compute_ctx.expired()) {
++            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++            ctx->compute_ctx = compute_ctx;
++            ggml_vk_ctx_begin(ctx->device, compute_ctx);
++        } else {
++            compute_ctx = ctx->compute_ctx.lock();
++        }
++        // initialize partial sums to zero.
++        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
++        ggml_vk_sync_buffers(ctx, compute_ctx);
++    }
++
++    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
++    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
++    // (and scaled down based on model size, so smaller models submit earlier).
++    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
++    int nodes_per_submit = 100;
++    int submitted_nodes = 0;
++    int submit_count = 0;
++    uint64_t mul_mat_bytes = 0;
++    uint64_t total_mul_mat_bytes = 0;
++    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
++    for (int i = 0; i < cgraph->n_nodes; i++) {
++        if (first_node_in_batch) {
++            submit_node_idx = i;
++        }
++
++        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
++            auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
++            mul_mat_bytes += bytes;
++            total_mul_mat_bytes += bytes;
++        }
++
++        // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
++        // the fused result in an elementwise-way. This affects whether the memory for
++        // the src is allowed to overlap the memory for the destination.
++        // The array is sized to handle the largest fusion (asserted later).
++        bool op_srcs_fused_elementwise[12];
++
++        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
++        ctx->fused_topk_moe_scale = false;
++        const char *fusion_string {};
++        if (!ctx->device->disable_fusion) {
++            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
++            if (num_adds) {
++                ctx->num_additional_fused_ops = num_adds - 1;
++                fusion_string = "MULTI_ADD";
++                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
++                ctx->num_additional_fused_ops = 2;
++                fusion_string = "MUL_MAT_ADD_ADD";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++                op_srcs_fused_elementwise[2] = true;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
++                ctx->num_additional_fused_ops = 1;
++                fusion_string = "MUL_MAT_ADD";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
++                ctx->num_additional_fused_ops = 2;
++                fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++                op_srcs_fused_elementwise[2] = true;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
++                ctx->num_additional_fused_ops = 1;
++                fusion_string = "MUL_MAT_ID_ADD_ID";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
++                ctx->num_additional_fused_ops = 1;
++                fusion_string = "MUL_MAT_ID_MUL";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
++                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
++                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
++                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
++                ctx->num_additional_fused_ops = 4;
++                fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = false;
++                op_srcs_fused_elementwise[2] = false;
++                op_srcs_fused_elementwise[3] = false;
++                op_srcs_fused_elementwise[4] = false;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
++                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
++                ctx->num_additional_fused_ops = 2;
++                fusion_string = "RMS_NORM_MUL_ROPE";
++                // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = true;
++                op_srcs_fused_elementwise[2] = true;
++            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
++                ctx->num_additional_fused_ops = 1;
++                fusion_string = "RMS_NORM_MUL";
++                // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
++                // they are overwritten, and one workgroup per row. So close enough.
++                op_srcs_fused_elementwise[0] = true;
++                op_srcs_fused_elementwise[1] = true;
++            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
++                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
++                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
++                ctx->num_additional_fused_ops = 2;
++                fusion_string = "ROPE_VIEW_SET_ROWS";
++                op_srcs_fused_elementwise[0] = false;
++                op_srcs_fused_elementwise[1] = false;
++                op_srcs_fused_elementwise[2] = false;
++            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
++                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
++                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
++                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
++                // view of argsort writes to memory
++                ctx->fused_ops_write_mask |= 1 << 3;
++                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
++                fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
++                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
++            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
++                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
++                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
++                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
++                // view of argsort writes to memory
++                ctx->fused_ops_write_mask |= 1 << 4;
++                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
++                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
++                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
++            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
++                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
++                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
++                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
++                // view of argsort writes to memory
++                ctx->fused_ops_write_mask |= 1 << 3;
++                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
++                fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
++                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
++            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
++                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
++                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
++                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
++                // view of argsort writes to memory
++                ctx->fused_ops_write_mask |= 1 << 1;
++                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
++                fusion_string = "TOPK_MOE_LATE_SOFTMAX";
++                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
++            }
++            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
++                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
++                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
++                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
++                    ctx->fused_topk_moe_scale = true;
++                    ctx->num_additional_fused_ops++;
++                    op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
++                }
++            }
++        }
++        GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
++        ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
++
++        // Check whether fusion would overwrite src operands while they're still in use.
++        // If so, disable fusion.
++        if (ctx->num_additional_fused_ops) {
++            // There are up to two output nodes - topk_moe has two.
++            uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
++            ggml_tensor *output_nodes[2] {};
++            output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
++            if (bits) {
++                int output_idx = find_first_set(bits);
++                GGML_ASSERT(bits == (1u << output_idx));
++                output_nodes[1] = cgraph->nodes[i + output_idx];
++            }
++
++            bool need_disable = false;
++
++            // topk_moe often overwrites the source, but for a given row all the src values are
++            // loaded before anything is stored. If there's only one row, this is safe, so treat
++            // this as a special case.
++            bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
++                                          ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
++
++            if (!is_topk_moe_single_row) {
++                for (int j = 0; j < 2; ++j) {
++                    ggml_tensor *dst = output_nodes[j];
++                    if (!dst) {
++                        continue;
++                    }
++                    // Loop over all srcs of all nodes in the fusion. If the src overlaps
++                    // the destination and the src is not an intermediate node that's being
++                    // elided, then disable fusion.
++                    for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
++                        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
++                            ggml_tensor *src = cgraph->nodes[i + k]->src[s];
++                            if (!src || src->op == GGML_OP_NONE) {
++                                continue;
++                            }
++                            if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
++                                bool found = false;
++                                for (int n = 0; n < k; ++n) {
++                                    if (cgraph->nodes[i + n] == src) {
++                                        found = true;
++                                        break;
++                                    }
++                                }
++                                if (!found) {
++                                    need_disable = true;
++                                }
++                            }
++                        }
++                    }
++                }
++            }
++            if (need_disable) {
++                ctx->num_additional_fused_ops = 0;
++                ctx->fused_ops_write_mask = 1;
++                ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
++                ctx->fused_topk_moe_scale = false;
++            }
++        }
++
++        // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
++        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
++        bool submit = (submitted_nodes >= nodes_per_submit) ||
++                      (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
++                      (i + ctx->num_additional_fused_ops >= last_node) ||
++                      (almost_ready && !ctx->almost_ready_fence_pending);
++
++        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
++
++        if (vk_perf_logger_enabled && enqueued) {
++            if (ctx->compute_ctx.expired()) {
++                compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++                ctx->compute_ctx = compute_ctx;
++                ggml_vk_ctx_begin(ctx->device, compute_ctx);
++            } else {
++                compute_ctx = ctx->compute_ctx.lock();
++            }
++            if (!vk_perf_logger_concurrent) {
++                // track a single node/fusion for the current query
++                ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
++                ctx->query_fusion_names[ctx->query_idx] = fusion_string;
++                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
++            } else {
++                // track a fusion string and number of fused ops for the current node_idx
++                ctx->query_fusion_names[i] = fusion_string;
++                ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
++            }
++        }
++
++        if (enqueued) {
++            ++submitted_nodes;
++
++#ifndef GGML_VULKAN_CHECK_RESULTS
++            if (first_node_in_batch) {
++                first_node_in_batch = false;
++            }
++#endif
++        }
++
++        if (submit && enqueued) {
++            first_node_in_batch = true;
++            submitted_nodes = 0;
++            mul_mat_bytes = 0;
++            if (submit_count < 3) {
++                mul_mat_bytes_per_submit *= 2;
++            }
++            submit_count++;
++        }
++        i += ctx->num_additional_fused_ops;
++        ctx->num_additional_fused_ops = 0;
++        ctx->fused_ops_write_mask = 0;
++    }
++
++    ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
++
++    if (vk_perf_logger_enabled) {
++        // End the command buffer and submit/wait
++        GGML_ASSERT(!ctx->compute_ctx.expired());
++        compute_ctx = ctx->compute_ctx.lock();
++        ggml_vk_ctx_end(compute_ctx);
++
++        ggml_vk_submit(compute_ctx, ctx->device->fence);
++        VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
++        ctx->device->device.resetFences({ ctx->device->fence });
++        ctx->compute_ctx.reset();
++
++        // Get the results and pass them to the logger
++        std::vector timestamps(cgraph->n_nodes + 1);
++        VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
++        if (!vk_perf_logger_concurrent) {
++            // Log each op separately
++            for (int i = 1; i < ctx->query_idx; i++) {
++                auto node = ctx->query_nodes[i];
++                auto name = ctx->query_fusion_names[i];
++                ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
++            }
++        } else {
++            // Log each group of nodes
++            int prev_node_idx = 0;
++            for (int i = 1; i < ctx->query_idx; i++) {
++                auto cur_node_idx = ctx->query_node_idx[i];
++                std::vector nodes;
++                std::vector names;
++                for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
++                    if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
++                        continue;
++                    }
++                    nodes.push_back(cgraph->nodes[node_idx]);
++                    names.push_back(ctx->query_fusion_names[node_idx]);
++                    node_idx += ctx->query_fusion_node_count[node_idx];
++                }
++                prev_node_idx = cur_node_idx;
++                ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
++            }
++        }
++        ctx->perf_logger->print_timings();
++    }
++
++    if (!ctx->device->support_async) {
++        ggml_vk_synchronize(ctx);
++    }
++
++    return GGML_STATUS_SUCCESS;
++
++    UNUSED(backend);
++}
++
++// Sort the graph for improved parallelism.
++//
++// Greedily reorders graph->nodes so that independent nodes are grouped
++// together, while keeping known fusion patterns (topk-moe variants,
++// RMS_NORM+MUL, ROPE+VIEW+SET_ROWS, MUL_MAT(_ID)+ADD(_ID)[+MUL/ADD])
++// consecutive so the compute pass can still fuse them.
++static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
++{
++    VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++
++    if (ctx->device->disable_graph_optimize) {
++        return;
++    }
++
++    // Nodes that generate no GPU work of their own (pure views/reshapes).
++    auto const &is_empty = [](ggml_tensor * node) -> bool {
++        return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
++    };
++
++    // True if dst consumes src directly, or if both alias the same underlying
++    // tensor via view_src (an implicit ordering dependency).
++    auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
++        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
++            if (dst->src[s] == src) {
++                return true;
++            }
++        }
++        // implicit dependency if they view the same tensor
++        const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
++        const ggml_tensor *src2 = src->view_src ? src->view_src : src;
++        if (dst2 == src2) {
++            return true;
++        }
++        return false;
++    };
++
++    // NOTE(review): template argument lists were missing in the previous text
++    // (apparently stripped by tooling); reconstructed here from usage.
++    std::vector<ggml_tensor *> new_order;
++    std::vector<bool> used(graph->n_nodes, false);
++    std::set<ggml_tensor *> used_node_set;
++
++    int first_unused = 0;
++    while (first_unused < graph->n_nodes) {
++        std::vector<int> current_set;
++
++        // Check for fusion patterns and avoid reordering them
++        auto const &match_pattern = [&](const std::initializer_list<enum ggml_op> &pattern, int start) -> bool {
++            if (start + (int)pattern.size() <= graph->n_nodes) {
++                bool is_pattern = true;
++                for (size_t j = 0; j < pattern.size(); ++j) {
++                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
++                        is_pattern = false;
++                    }
++                }
++                return is_pattern;
++            }
++            return false;
++        };
++
++        // If the nodes at first_unused match the pattern, emit them in place
++        // (in their original order) and advance past them.
++        auto const &keep_pattern = [&](const std::initializer_list<enum ggml_op> &pattern) -> bool {
++            if (match_pattern(pattern, first_unused)) {
++                for (size_t j = 0; j < pattern.size(); ++j) {
++                    new_order.push_back(graph->nodes[first_unused + j]);
++                    used_node_set.insert(graph->nodes[first_unused + j]);
++                    used[first_unused + j] = true;
++                }
++                while (first_unused < graph->n_nodes && used[first_unused]) {
++                    first_unused++;
++                }
++                return true;
++            }
++            return false;
++        };
++
++        if (keep_pattern(topk_moe_early_softmax_norm)) {
++            continue;
++        }
++        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
++            continue;
++        }
++        if (keep_pattern(topk_moe_early_softmax)) {
++            continue;
++        }
++        if (keep_pattern(topk_moe_late_softmax)) {
++            continue;
++        }
++
++        // First, grab the next unused node.
++        current_set.push_back(first_unused);
++
++        // Loop through the next N nodes. Grab any that don't depend on other nodes that
++        // haven't already been run. Nodes that have already been run have used[i] set
++        // to true. Allow nodes that depend on the previous node if it's a fusion pattern
++        // that we support (e.g. RMS_NORM + MUL).
++        // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes.
++        // The goal is to not interleave real and view nodes in a way that breaks fusion.
++        const int NUM_TO_CHECK = 20;
++        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
++            if (used[j]) {
++                continue;
++            }
++            if (is_empty(graph->nodes[j])) {
++                continue;
++            }
++            // Don't pull forward nodes from fusion patterns
++            if (match_pattern(topk_moe_early_softmax_norm, j) ||
++                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
++                match_pattern(topk_moe_early_softmax, j) ||
++                match_pattern(topk_moe_late_softmax, j)) {
++                continue;
++            }
++            bool ok = true;
++            for (int c = first_unused; c < j; ++c) {
++                if (!used[c] &&
++                    is_src_of(graph->nodes[j], graph->nodes[c]) &&
++                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
++                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
++                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
++                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
++                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
++                    ok = false;
++                    break;
++                }
++            }
++            if (ok) {
++                current_set.push_back(j);
++
++                int rope_idx = j;
++
++                // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
++                if (j > 0 &&
++                    graph->nodes[j]->op == GGML_OP_MUL &&
++                    graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
++                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
++                        if (graph->nodes[k]->op == GGML_OP_ROPE &&
++                            graph->nodes[k]->src[0] == graph->nodes[j] &&
++                            // Check that other srcs are already valid
++                            graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
++                            (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
++                            rope_idx = k;
++                            current_set.push_back(rope_idx);
++                            used[rope_idx] = true;
++                            break;
++                        }
++                    }
++                }
++                // Look for ROPE + VIEW + SET_ROWS and make them consecutive
++                if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
++                    int view_idx = -1;
++                    int set_rows_idx = -1;
++                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
++                        if (view_idx == -1 &&
++                            graph->nodes[k]->op == GGML_OP_VIEW &&
++                            graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
++                            view_idx = k;
++                            continue;
++                        }
++                        if (view_idx != -1 &&
++                            set_rows_idx == -1 &&
++                            graph->nodes[k]->op == GGML_OP_SET_ROWS &&
++                            graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
++                            set_rows_idx = k;
++                            break;
++                        }
++                    }
++                    if (set_rows_idx != -1) {
++                        current_set.push_back(view_idx);
++                        current_set.push_back(set_rows_idx);
++                        used[view_idx] = true;
++                        used[set_rows_idx] = true;
++                    }
++                }
++                // Look for MUL_MAT_ID + ADD_ID + MUL
++                if (j > 0 &&
++                    graph->nodes[j]->op == GGML_OP_ADD_ID &&
++                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
++                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
++                        if (graph->nodes[k]->op == GGML_OP_MUL &&
++                            graph->nodes[k]->src[0] == graph->nodes[j] &&
++                            // src1 must either be weights or already processed
++                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
++                            current_set.push_back(k);
++                            used[k] = true;
++                            break;
++                        }
++                    }
++                }
++                // Look for MUL_MAT + ADD + ADD
++                if (j > 0 &&
++                    graph->nodes[j]->op == GGML_OP_ADD &&
++                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
++                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
++                        if (graph->nodes[k]->op == GGML_OP_ADD &&
++                            graph->nodes[k]->src[0] == graph->nodes[j] &&
++                            // src1 must either be weights or already processed
++                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
++                            current_set.push_back(k);
++                            used[k] = true;
++                            break;
++                        }
++                    }
++                }
++            }
++        }
++        // Second pass grabs view nodes.
++        // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
++        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
++            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
++                if (used[j]) {
++                    continue;
++                }
++                if (!is_empty(graph->nodes[j])) {
++                    continue;
++                }
++                bool ok = true;
++                for (int c = first_unused; c < j; ++c) {
++                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
++                    // skip views whose srcs haven't been processed.
++                    if (!used[c] &&
++                        is_src_of(graph->nodes[j], graph->nodes[c]) &&
++                        !c_in_current_set) {
++                        ok = false;
++                        break;
++                    }
++                }
++                if (ok) {
++                    current_set.push_back(j);
++                }
++            }
++        }
++
++        // Push the current set into new_order
++        for (auto c : current_set) {
++            new_order.push_back(graph->nodes[c]);
++            used_node_set.insert(graph->nodes[c]);
++            used[c] = true;
++        }
++        while (first_unused < graph->n_nodes && used[first_unused]) {
++            first_unused++;
++        }
++    }
++    // Replace the graph with the new order.
++    for (int i = 0; i < graph->n_nodes; ++i) {
++        graph->nodes[i] = new_order[i];
++    }
++}
++
++// Record (signal) a backend event: resets the Vulkan event/fence, then
++// submits a command that sets the event on the compute queue.
++static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
++    VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    vk_event *vkev = (vk_event *)event->context;
++
++    vk_context compute_ctx;
++
++    // Reuse the live compute context if one exists, otherwise begin a new one.
++    if (ctx->compute_ctx.expired()) {
++        // Initialize new transfer context
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++    } else {
++        compute_ctx = ctx->compute_ctx.lock();
++    }
++
++    // the backend interface doesn't have an explicit reset, so reset it here
++    // before we record the command to set it
++    ctx->device->device.resetEvent(vkev->event);
++    ctx->device->device.resetFences({ vkev->fence });
++
++    ggml_vk_set_event(compute_ctx, vkev->event);
++
++    ggml_vk_ctx_end(compute_ctx);
++
++    // Submit with the event's fence and mark that a submit is outstanding.
++    ggml_vk_submit(compute_ctx, {vkev->fence});
++    ctx->submit_pending = true;
++    ctx->compute_ctx.reset();
++}
++
++// Make the compute queue wait on a previously recorded backend event by
++// recording a vkCmdWaitEvents-style wait into the current compute context.
++static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
++    VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
++    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
++    vk_event *vkev = (vk_event *)event->context;
++
++    vk_context compute_ctx;
++
++    // Reuse the live compute context if one exists, otherwise begin a new one.
++    if (ctx->compute_ctx.expired()) {
++        // Initialize new transfer context
++        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
++        ctx->compute_ctx = compute_ctx;
++        ggml_vk_ctx_begin(ctx->device, compute_ctx);
++    } else {
++        compute_ctx = ctx->compute_ctx.lock();
++    }
++
++    ggml_vk_wait_events(compute_ctx, {vkev->event});
++    ggml_vk_ctx_end(compute_ctx);
++    ctx->compute_ctx.reset();
++}
++
++// TODO: enable async and synchronize
++// Backend vtable wiring the Vulkan implementations into the ggml backend
++// interface. Graph-plan entry points are unimplemented (NULL).
++static ggml_backend_i ggml_backend_vk_interface = {
++    /* .get_name                = */ ggml_backend_vk_name,
++    /* .free                    = */ ggml_backend_vk_free,
++    /* .set_tensor_async        = */ ggml_backend_vk_set_tensor_async,
++    /* .get_tensor_async        = */ ggml_backend_vk_get_tensor_async,
++    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
++    /* .synchronize             = */ ggml_backend_vk_synchronize,
++    /* .graph_plan_create       = */ NULL,
++    /* .graph_plan_free         = */ NULL,
++    /* .graph_plan_update       = */ NULL,
++    /* .graph_plan_compute      = */ NULL,
++    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
++    /* .event_record            = */ ggml_backend_vk_event_record,
++    /* .event_wait              = */ ggml_backend_vk_event_wait,
++    /* .graph_optimize          = */ ggml_vk_graph_optimize,
++};
++
++// Returns the stable GUID identifying the Vulkan backend (used by
++// ggml_backend_is_vk to recognize backend instances).
++static ggml_guid_t ggml_backend_vk_guid() {
++    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
++    return &guid;
++}
++
++// Create a Vulkan backend instance for device dev_num. Initializes the
++// per-backend context and, when the device does not support async transfers,
++// disables the async get_tensor entry point on this instance.
++ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
++    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
++
++    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
++    ggml_vk_init(ctx, dev_num);
++
++    ggml_backend_t vk_backend = new ggml_backend {
++        /* .guid    = */ ggml_backend_vk_guid(),
++        /* .iface   = */ ggml_backend_vk_interface,
++        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
++        /* .context = */ ctx,
++    };
++
++    // Per-instance override of the shared interface table.
++    if (!ctx->device->support_async) {
++        vk_backend->iface.get_tensor_async = nullptr;
++    }
++
++    return vk_backend;
++}
++
++// Returns true if the (non-NULL) backend is a Vulkan backend, by GUID match.
++bool ggml_backend_is_vk(ggml_backend_t backend) {
++    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
++}
++
++// Number of Vulkan devices visible to the backend.
++int ggml_backend_vk_get_device_count() {
++    return ggml_vk_get_device_count();
++}
++
++// Copy a human-readable description of logical device index `device`
++// into `description` (at most description_size bytes).
++void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
++    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
++    // Map the backend's logical index to the physical-device index.
++    int dev_idx = vk_instance.device_indices[device];
++    ggml_vk_get_device_description(dev_idx, description, description_size);
++}
++
++// Report total and free memory for the given device by summing device-local
++// heaps (all heaps for integrated GPUs). Uses VK_EXT_memory_budget for the
++// free estimate when supported; otherwise assumes the whole heap is free.
++void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
++    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
++    GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
++
++    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
++    vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
++    vk::PhysicalDeviceMemoryProperties2 memprops = {};
++    const bool membudget_supported = vk_instance.device_supports_membudget[device];
++    const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
++
++    // Chain the budget struct only when the extension is supported.
++    if (membudget_supported) {
++        memprops.pNext = &budgetprops;
++    }
++    vkdev.getMemoryProperties2(&memprops);
++
++    *total = 0;
++    *free = 0;
++
++    for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
++        const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
++
++        if (is_integrated_gpu || (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal)) {
++            *total += heap.size;
++
++            // NOTE(review): assumes heapBudget >= heapUsage; a driver reporting
++            // usage above budget would underflow this subtraction — confirm.
++            if (membudget_supported && i < budgetprops.heapUsage.size()) {
++                *free += budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
++            } else {
++                *free += heap.size;
++            }
++        }
++    }
++}
++
++// Query the Vulkan physical-device type (discrete/integrated/...) for the
++// backend's logical device index.
++static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {
++    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
++
++    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
++
++    vk::PhysicalDeviceProperties2 props = {};
++    device.getProperties2(&props);
++
++    return props.properties.deviceType;
++}
++
++// Return the device's PCI bus id as "dddd:bb:dd.f" via VK_EXT_pci_bus_info,
++// or an empty string when the extension is unavailable.
++static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
++    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
++
++    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
++
++    // NOTE(review): the template argument was missing in the previous text
++    // (apparently stripped by tooling); restored from the API's return type.
++    const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
++
++    bool ext_support = false;
++
++    for (const auto& properties : ext_props) {
++        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
++            ext_support = true;
++            break;
++        }
++    }
++
++    if (!ext_support) {
++        return "";
++    }
++
++    vk::PhysicalDeviceProperties2 props = {};
++    vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
++
++    props.pNext = &pci_bus_info;
++
++    device.getProperties2(&props);
++
++    const uint32_t pci_domain = pci_bus_info.pciDomain;
++    const uint32_t pci_bus = pci_bus_info.pciBus;
++    const uint32_t pci_device = pci_bus_info.pciDevice;
++    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
++
++    char pci_bus_id[16] = {};
++    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
++
++    return std::string(pci_bus_id);
++}
++
++//////////////////////////
++
++// Per-device context stored in ggml_backend_dev_t::context for the Vulkan
++// backend's device interface.
++struct ggml_backend_vk_device_context {
++    size_t device;                  // logical device index into vk_instance.device_indices
++    std::string name;               // backend device name
++    std::string description;        // human-readable device description
++    bool is_integrated_gpu;         // true for integrated GPUs (affects reported type/memory)
++    std::string pci_bus_id;         // "dddd:bb:dd.f" PCI id, empty if unavailable
++    int op_offload_min_batch_size;  // minimum batch size for op offload
++};
++
++// Device-interface: name accessor.
++static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    return ctx->name.c_str();
++}
++
++// Device-interface: description accessor.
++static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    return ctx->description.c_str();
++}
++
++// Device-interface: forwards memory query to ggml_backend_vk_get_device_memory.
++static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
++    ggml_backend_vk_get_device_memory(ctx->device, free, total);
++}
++
++// Device-interface: device-local buffer type for this device.
++static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    return ggml_backend_vk_buffer_type(ctx->device);
++}
++
++// Device-interface: host (pinned) buffer type; shared across devices.
++static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
++    UNUSED(dev);
++    return ggml_backend_vk_host_buffer_type();
++}
++
++// Device-interface: reports IGPU for integrated GPUs, GPU otherwise.
++static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++
++    return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;
++}
++
++// Device-interface: fill the ggml_backend_dev_props struct (name, description,
++// type, PCI id, memory, and capability flags).
++static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++
++    props->name        = ggml_backend_vk_device_get_name(dev);
++    props->description = ggml_backend_vk_device_get_description(dev);
++    props->type        = ggml_backend_vk_device_get_type(dev);
++    // device_id is NULL when no PCI bus info is available.
++    props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
++    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
++    props->caps = {
++        /* .async                 = */ true,
++        /* .host_buffer           = */ true,
++        /* .buffer_from_host_ptr  = */ false,
++        /* .events                = */ true,
++    };
++}
++
++// Device-interface: instantiate a backend for this device; params are unused.
++static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
++    UNUSED(params);
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    return ggml_backend_vk_init(ctx->device);
++}
++
++// Report whether this Vulkan device can execute `op`. First rejects any tensor
++// larger than the device's buffer limits, then dispatches on op type/dtype.
++// NOTE(review): this table presumably must stay in sync with the pipelines the
++// backend compiles elsewhere — verify when adding cases.
++static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    const vk_device& device = ggml_vk_get_device(ctx->device);
++
++    const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
++                          device->shader_int64 && device->buffer_device_address;
++
++    auto const & tensor_size_supported = [&](size_t tensor_size) {
++        if (tensor_size > device->max_buffer_size) {
++            return false;
++        }
++        // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
++        // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
++        if (!uses_bda && !device->shader_64b_indexing) {
++            if (tensor_size > device->properties.limits.maxStorageBufferRange) {
++                return false;
++            }
++        }
++        return true;
++    };
++    // reject any tensors larger than the max buffer size
++    for (int i = 0; i < GGML_MAX_SRC; i++) {
++        if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
++            return false;
++        }
++    }
++    if (!tensor_size_supported(ggml_nbytes(op))) {
++        return false;
++    }
++
++    switch (op->op) {
++        case GGML_OP_UNARY:
++            switch (ggml_get_unary_op(op)) {
++                case GGML_UNARY_OP_EXP:
++                case GGML_UNARY_OP_GELU:
++                case GGML_UNARY_OP_GELU_ERF:
++                case GGML_UNARY_OP_GELU_QUICK:
++                case GGML_UNARY_OP_SILU:
++                case GGML_UNARY_OP_RELU:
++                case GGML_UNARY_OP_XIELU:
++                case GGML_UNARY_OP_NEG:
++                case GGML_UNARY_OP_TANH:
++                case GGML_UNARY_OP_SIGMOID:
++                case GGML_UNARY_OP_HARDSIGMOID:
++                case GGML_UNARY_OP_HARDSWISH:
++                case GGML_UNARY_OP_ABS:
++                case GGML_UNARY_OP_SOFTPLUS:
++                case GGML_UNARY_OP_STEP:
++                case GGML_UNARY_OP_ROUND:
++                case GGML_UNARY_OP_CEIL:
++                case GGML_UNARY_OP_FLOOR:
++                case GGML_UNARY_OP_TRUNC:
++                    // unary ops: contiguous f32->f32 or f16->f16 only
++                    return ggml_is_contiguous(op->src[0]) &&
++                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
++                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
++                           (op->src[0]->type == op->type);
++                default:
++                    return false;
++            }
++        case GGML_OP_GLU:
++            switch (ggml_get_glu_op(op)) {
++                case GGML_GLU_OP_GEGLU:
++                case GGML_GLU_OP_REGLU:
++                case GGML_GLU_OP_SWIGLU:
++                case GGML_GLU_OP_SWIGLU_OAI:
++                case GGML_GLU_OP_GEGLU_ERF:
++                case GGML_GLU_OP_GEGLU_QUICK:
++                    return ggml_is_contiguous(op->src[0]) &&
++                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
++                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
++                           (op->src[0]->type == op->type);
++                default:
++                    return false;
++            }
++        case GGML_OP_MUL_MAT:
++        case GGML_OP_MUL_MAT_ID:
++            {
++                ggml_type src0_type = op->src[0]->type;
++                if (op->op == GGML_OP_MUL_MAT_ID) {
++                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
++                        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
++                        return false;
++                    }
++                }
++                switch (src0_type) {
++                    case GGML_TYPE_F32:
++                    case GGML_TYPE_F16:
++                    case GGML_TYPE_BF16:
++                    case GGML_TYPE_Q4_0:
++                    case GGML_TYPE_Q4_1:
++                    case GGML_TYPE_Q5_0:
++                    case GGML_TYPE_Q5_1:
++                    case GGML_TYPE_Q8_0:
++                    case GGML_TYPE_Q2_K:
++                    case GGML_TYPE_Q3_K:
++                    case GGML_TYPE_Q4_K:
++                    case GGML_TYPE_Q5_K:
++                    case GGML_TYPE_Q6_K:
++                    case GGML_TYPE_IQ1_S:
++                    case GGML_TYPE_IQ1_M:
++                    case GGML_TYPE_IQ2_XXS:
++                    case GGML_TYPE_IQ2_XS:
++                    case GGML_TYPE_IQ2_S:
++                    case GGML_TYPE_IQ3_XXS:
++                    case GGML_TYPE_IQ3_S:
++                    case GGML_TYPE_IQ4_XS:
++                    case GGML_TYPE_IQ4_NL:
++                    case GGML_TYPE_MXFP4:
++                        break;
++                    default:
++                        return false;
++                }
++                // For MUL_MAT_ID the weight matrix is src[2]; batch (ne[3]) must match.
++                struct ggml_tensor * a;
++                struct ggml_tensor * b;
++                if (op->op == GGML_OP_MUL_MAT) {
++                    a = op->src[0];
++                    b = op->src[1];
++                } else {
++                    a = op->src[2];
++                    b = op->src[1];
++                }
++                if (a->ne[3] != b->ne[3]) {
++                    return false;
++                }
++                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
++                    !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
++                    return false;
++                }
++                if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
++                    // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
++                    // So don't support this combination for now.
++                    return false;
++                }
++
++                return true;
++            }
++        case GGML_OP_FLASH_ATTN_EXT:
++            {
++                bool coopmat2 = device->coopmat2;
++                uint32_t HSK = op->src[1]->ne[0];
++                uint32_t HSV = op->src[2]->ne[0];
++                if ((HSK % 8) != 0 || (HSV % 8) != 0) {
++                    return false;
++                }
++                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
++                    return false;
++                }
++                if (op->src[0]->type != GGML_TYPE_F32) {
++                    return false;
++                }
++                if (op->type != GGML_TYPE_F32) {
++                    return false;
++                }
++                if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
++                    return false;
++                }
++                // It's straightforward to support different K/V dequant, but would
++                // significantly increase the number of pipelines
++                if (op->src[1]->type != op->src[2]->type) {
++                    return false;
++                }
++                switch (op->src[1]->type) {
++                case GGML_TYPE_F16:
++                case GGML_TYPE_F32:
++                case GGML_TYPE_Q4_0:
++                case GGML_TYPE_Q8_0:
++                    // supported in scalar and coopmat2 paths
++                    break;
++                case GGML_TYPE_Q4_1:
++                case GGML_TYPE_Q5_0:
++                case GGML_TYPE_Q5_1:
++                // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
++                //case GGML_TYPE_Q2_K:
++                //case GGML_TYPE_Q3_K:
++                //case GGML_TYPE_Q4_K:
++                //case GGML_TYPE_Q5_K:
++                //case GGML_TYPE_Q6_K:
++                //case GGML_TYPE_IQ1_S:
++                //case GGML_TYPE_IQ1_M:
++                //case GGML_TYPE_IQ2_XXS:
++                //case GGML_TYPE_IQ2_XS:
++                //case GGML_TYPE_IQ2_S:
++                //case GGML_TYPE_IQ3_XXS:
++                //case GGML_TYPE_IQ3_S:
++                //case GGML_TYPE_IQ4_XS:
++                case GGML_TYPE_IQ4_NL:
++                    // currently supported only in coopmat2 path
++                    if (!coopmat2) {
++                        return false;
++                    }
++                    break;
++                default:
++                    return false;
++                }
++                if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
++                    // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
++                    return false;
++                }
++                return true;
++            }
++        case GGML_OP_GET_ROWS:
++            {
++                switch (op->src[0]->type) {
++                    case GGML_TYPE_F32:
++                    case GGML_TYPE_F16:
++                    case GGML_TYPE_BF16:
++                    case GGML_TYPE_Q4_0:
++                    case GGML_TYPE_Q4_1:
++                    case GGML_TYPE_Q5_0:
++                    case GGML_TYPE_Q5_1:
++                    case GGML_TYPE_Q8_0:
++                    case GGML_TYPE_Q2_K:
++                    case GGML_TYPE_Q3_K:
++                    case GGML_TYPE_Q4_K:
++                    case GGML_TYPE_Q5_K:
++                    case GGML_TYPE_Q6_K:
++                    case GGML_TYPE_IQ1_S:
++                    case GGML_TYPE_IQ1_M:
++                    case GGML_TYPE_IQ2_XXS:
++                    case GGML_TYPE_IQ2_XS:
++                    case GGML_TYPE_IQ2_S:
++                    case GGML_TYPE_IQ3_XXS:
++                    case GGML_TYPE_IQ3_S:
++                    case GGML_TYPE_IQ4_XS:
++                    case GGML_TYPE_IQ4_NL:
++                    case GGML_TYPE_MXFP4:
++                    case GGML_TYPE_I32:
++                        return true;
++                    default:
++                        return false;
++                }
++            }
++        case GGML_OP_SET_ROWS:
++            {
++                switch (op->type) {
++                    case GGML_TYPE_F32:
++                    case GGML_TYPE_F16:
++                    case GGML_TYPE_BF16:
++                    case GGML_TYPE_Q4_0:
++                    case GGML_TYPE_Q4_1:
++                    case GGML_TYPE_Q5_0:
++                    case GGML_TYPE_Q5_1:
++                    case GGML_TYPE_Q8_0:
++                    case GGML_TYPE_IQ4_NL:
++                        return true;
++                    default:
++                        return false;
++                }
++            }
++        case GGML_OP_CONT:
++        case GGML_OP_CPY:
++        case GGML_OP_DUP:
++            {
++                ggml_type src0_type = op->src[0]->type;
++                ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
++
++                if (src0_type == GGML_TYPE_F32) {
++                    switch (src1_type) {
++                    case GGML_TYPE_F32:
++                    case GGML_TYPE_F16:
++                    case GGML_TYPE_BF16:
++                    case GGML_TYPE_Q4_0:
++                    case GGML_TYPE_Q4_1:
++                    case GGML_TYPE_Q5_0:
++                    case GGML_TYPE_Q5_1:
++                    case GGML_TYPE_Q8_0:
++                    case GGML_TYPE_IQ4_NL:
++                        return true;
++                    default:
++                        break;
++                    }
++                }
++                if (src1_type == GGML_TYPE_F32) {
++                    switch (src0_type) {
++                    case GGML_TYPE_F16:
++                    case GGML_TYPE_Q4_0:
++                    case GGML_TYPE_Q4_1:
++                    case GGML_TYPE_Q5_0:
++                    case GGML_TYPE_Q5_1:
++                    case GGML_TYPE_Q8_0:
++                    case GGML_TYPE_IQ4_NL:
++                        return true;
++                    default:
++                        break;
++                    }
++                }
++
++                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
++                    return true;
++                }
++
++                if (
++                    (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
++                    (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
++                ) {
++                    return true;
++                }
++
++                // We can handle copying from a type to the same type if it's
++                // either not quantized or is quantized and contiguous.
++                // We use f16 or f32 shaders to do the copy,
++                // so the type/block size must be a multiple of 4.
++                if (src0_type == src1_type &&
++                    (!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op))) &&
++                    (ggml_type_size(src0_type) % 2) == 0) {
++                    return true;
++                }
++                return false;
++            }
++        case GGML_OP_REPEAT:
++            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
++        case GGML_OP_REPEAT_BACK:
++            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_ROPE:
++            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
++        case GGML_OP_ROPE_BACK:
++        case GGML_OP_NONE:
++        case GGML_OP_RESHAPE:
++        case GGML_OP_VIEW:
++        case GGML_OP_PERMUTE:
++        case GGML_OP_TRANSPOSE:
++        case GGML_OP_RMS_NORM:
++            return true;
++        case GGML_OP_NORM:
++        case GGML_OP_GROUP_NORM:
++            return ggml_is_contiguous(op->src[0]);
++        case GGML_OP_L2_NORM:
++            return ggml_is_contiguous_rows(op->src[0]) &&
++                   op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
++        case GGML_OP_ADD:
++        case GGML_OP_SUB:
++        case GGML_OP_MUL:
++        case GGML_OP_DIV:
++            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
++                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
++                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
++        case GGML_OP_ADD_ID:
++            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
++                   op->type == GGML_TYPE_F32;
++        case GGML_OP_SILU_BACK:
++        case GGML_OP_RMS_NORM_BACK:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_SQR:
++        case GGML_OP_SQRT:
++        case GGML_OP_SIN:
++        case GGML_OP_COS:
++        case GGML_OP_CLAMP:
++            return op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_LEAKY_RELU:
++        case GGML_OP_OPT_STEP_ADAMW:
++        case GGML_OP_OPT_STEP_SGD:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_LOG:
++        case GGML_OP_TRI:
++        case GGML_OP_DIAG:
++            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
++                   op->type == op->src[0]->type;
++        case GGML_OP_ARGSORT:
++            {
++                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
++                    return false;
++                }
++                // pipeline_argsort_large_f32 requires vulkan memory model.
++                if (device->vulkan_memory_model) {
++                    return true;
++                } else {
++                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);
++                }
++            }
++        case GGML_OP_TOP_K:
++            {
++                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
++                    return false;
++                }
++                // We could potentially support larger, using argsort to sort the
++                // whole thing. Not clear if this is needed.
++                uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
++                if (min_pipeline >= num_topk_pipelines ||
++                    !device->pipeline_topk_f32[min_pipeline]) {
++                    return false;
++                }
++            }
++            return true;
++        case GGML_OP_UPSCALE:
++            if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
++                // antialiasing is only implemented for bilinear scaling
++                if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
++                    return false;
++                }
++            }
++            return op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_ACC:
++            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
++        case GGML_OP_SET:
++            return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
++                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
++        case GGML_OP_CONCAT:
++            return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
++        case GGML_OP_ADD1:
++            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
++                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
++                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
++        case GGML_OP_ARANGE:
++        case GGML_OP_FILL:
++            return op->type == GGML_TYPE_F32;
++        case GGML_OP_SCALE:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_PAD:
++        case GGML_OP_ROLL:
++            return op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_DIAG_MASK_INF:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_SOFT_MAX:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
++                && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
++        case GGML_OP_SOFT_MAX_BACK:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
++                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
++        case GGML_OP_SUM:
++        case GGML_OP_SUM_ROWS:
++        case GGML_OP_MEAN:
++            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
++        case GGML_OP_CUMSUM:
++            {
++                // cumsum shader relies on subgroup arithmetic with full subgroup support
++                if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
++                    return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
++                }
++                return false;
++            }
++        case GGML_OP_SOLVE_TRI:
++            {
++                if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
++                    return false;
++                }
++                const uint32_t N = op->src[0]->ne[0];
++                const uint32_t K = op->src[1]->ne[0];
++                // K dimension limited to workgroup size
++                if (K > 1u << device->max_workgroup_size_log2) {
++                    return false;
++                }
++                const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
++
++                if (batch_N == 0) {
++                    return false;
++                }
++                return true;
++            }
++        case GGML_OP_ARGMAX:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_COUNT_EQUAL:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
++                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
++        case GGML_OP_IM2COL:
++            return ggml_is_contiguous(op->src[1])
++                && op->src[1]->type == GGML_TYPE_F32
++                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
++        case GGML_OP_IM2COL_3D:
++            return op->src[1]->type == GGML_TYPE_F32
++                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
++        case GGML_OP_TIMESTEP_EMBEDDING:
++            return op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_CONV_2D_DW:
++            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
++                && op->src[1]->type == GGML_TYPE_F32;
++        case GGML_OP_POOL_2D:
++            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_RWKV_WKV6:
++        case GGML_OP_RWKV_WKV7:
++            return true; // all inputs are contiguous, see ggml.c
++        case GGML_OP_SSM_SCAN:
++            {
++                for (int i = 0; i < 6; i++) {
++                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
++                        return false;
++                    }
++                }
++                if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
++                    return false;
++                }
++                if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
++                    return false;
++                }
++
++                const uint32_t d_state = op->src[0]->ne[0];
++                const uint32_t head_dim = op->src[0]->ne[1];
++
++                bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
++                if (!is_mamba2) {
++                    return false;
++                }
++
++                if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
++                    return false;
++                }
++
++                size_t shmem_size = d_state * sizeof(float);
++
++                if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
++                    return false;
++                }
++
++                if (!device->subgroup_basic) {
++                    return false;
++                }
++
++                return true;
++            }
++        case GGML_OP_SSM_CONV:
++            return op->src[0]->type == GGML_TYPE_F32;
++        case GGML_OP_CONV_TRANSPOSE_1D:
++            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
++        case GGML_OP_CONV_2D:
++        case GGML_OP_CONV_TRANSPOSE_2D:
++            {
++                // Channel-contiguous format is not supported yet.
++                return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
++                    op->src[1]->type == GGML_TYPE_F32 &&
++                    op->type == GGML_TYPE_F32 &&
++                    ggml_is_contiguous(op->src[0]) &&
++                    ggml_is_contiguous(op->src[1]) &&
++                    ggml_is_contiguous(op));
++            }
++        default:
++            return false;
++    }
++
++    UNUSED(dev);
++}
++
++// A buffer type is usable only if it is a Vulkan buffer type backed by the
++// same device index as this device.
++static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
++    if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
++        return false;
++    }
++
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
++
++    return buft_ctx->device->idx == ctx->device;
++}
++
++// Offload an op to the GPU only when its batch dimension is large enough to be
++// worth the transfer: ne[1] for most ops (GET_ROWS excluded), ne[2] for MUL_MAT_ID.
++static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
++    ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
++
++    return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
++           (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
++}
++
++// Allocate a backend event backed by a VkEvent (GPU-side signal) plus a VkFence
++// (host-side wait); both start signaled.
++static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    auto device = ggml_vk_get_device(ctx->device);
++
++    vk_event *vkev = new vk_event;
++    // NOTE(review): plain `new` throws std::bad_alloc rather than returning
++    // nullptr, so this check is effectively dead.
++    if (!vkev) {
++        return nullptr;
++    }
++
++    // The event/fence is expected to initially be in the signaled state.
++    vkev->event = device->device.createEvent({});
++    vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
++    device->device.setEvent(vkev->event);
++
++    return new ggml_backend_event {
++        /* .device  = */ dev,
++        /* .context = */ vkev,
++    };
++}
++
++// Destroy the Vulkan fence/event pair and free both wrapper allocations.
++static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    auto device = ggml_vk_get_device(ctx->device);
++
++    vk_event *vkev = (vk_event *)event->context;
++
++    device->device.destroyFence(vkev->fence);
++    device->device.destroyEvent(vkev->event);
++    delete vkev;
++    delete event;
++}
++
++// Block on the host until the event's fence is signaled (no timeout).
++static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
++    VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    auto device = ggml_vk_get_device(ctx->device);
++    vk_event *vkev = (vk_event *)event->context;
++
++    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
++}
++
++// Try to import an existing host allocation as a Vulkan buffer via
++// VK_EXT_external_memory_host. Returns an empty vk_buffer when the extension
++// is unavailable, when ptr/size don't meet the device's import alignment, or
++// when buffer creation fails.
++static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
++    if (!device->external_memory_host) {
++        return {};
++    }
++
++    // Both the pointer and the size must be multiples of
++    // minImportedHostPointerAlignment (a power of two) to be importable.
++    uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
++    if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
++        return {};
++    }
++    if (size & (device->min_imported_host_pointer_alignment - 1)) {
++        return {};
++    }
++
++    const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
++
++    vk_buffer buf {};
++    try {
++        buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
++    } catch (vk::SystemError& e) {
++        GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
++    }
++
++    return buf;
++}
++
++// Backend-device wrapper around ggml_vk_buffer_from_host_ptr: wraps the
++// imported buffer in a ggml_backend_buffer, or returns an empty handle on
++// failure.
++static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
++    VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
++    GGML_UNUSED(max_tensor_size);
++
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    auto device = ggml_vk_get_device(ctx->device);
++
++    vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
++
++    if (!buf) {
++        return {};
++    }
++
++    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
++
++    ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
++
++    return ret;
++}
++
++// Device interface vtable wiring the Vulkan implementations into the generic
++// ggml backend device API.
++static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
++    /* .get_name             = */ ggml_backend_vk_device_get_name,
++    /* .get_description      = */ ggml_backend_vk_device_get_description,
++    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
++    /* .get_type             = */ ggml_backend_vk_device_get_type,
++    /* .get_props            = */ ggml_backend_vk_device_get_props,
++    /* .init_backend         = */ ggml_backend_vk_device_init,
++    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
++    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
++    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
++    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
++    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
++    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
++    /* .event_new            = */ ggml_backend_vk_device_event_new,
++    /* .event_free           = */ ggml_backend_vk_device_event_free,
++    /* .event_synchronize    = */ ggml_backend_vk_device_event_synchronize,
++};
++
++// Registry name, e.g. shown in backend listings.
++static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
++    UNUSED(reg);
++    return GGML_VK_NAME;
++}
++
++// Number of Vulkan devices visible to this registry.
++static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
++    UNUSED(reg);
++    return ggml_backend_vk_get_device_count();
++}
++
++// Return the ggml_backend_device for `device`. Device objects are built once,
++// lazily and under a mutex, then cached for the life of the process.
++static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
++    static std::vector<ggml_backend_dev_t> devices;
++
++    static bool initialized = false;
++
++    {
++        static std::mutex mutex;
++        std::lock_guard<std::mutex> lock(mutex);
++        if (!initialized) {
++            // GGML_OP_OFFLOAD_MIN_BATCH overrides the minimum batch size used by
++            // offload_op (default 32).
++            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
++            for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
++                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
++                char desc[256];
++                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
++                ctx->device = i;
++                ctx->name = GGML_VK_NAME + std::to_string(i);
++                ctx->description = desc;
++                ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
++                ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
++                ctx->op_offload_min_batch_size = min_batch_size;
++                devices.push_back(new ggml_backend_device {
++                    /* .iface   = */ ggml_backend_vk_device_i,
++                    /* .reg     = */ reg,
++                    /* .context = */ ctx,
++                });
++            }
++            initialized = true;
++        }
++    }
++
++    GGML_ASSERT(device < devices.size());
++    return devices[device];
++}
++
++// Registry interface vtable; no custom proc addresses are exported.
++static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
++    /* .get_name         = */ ggml_backend_vk_reg_get_name,
++    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
++    /* .get_device       = */ ggml_backend_vk_reg_get_device,
++    /* .get_proc_address = */ NULL,
++};
++
++// Return the Vulkan backend registry, or nullptr if Vulkan instance
++// initialization fails (missing loader/driver, etc.).
++ggml_backend_reg_t ggml_backend_vk_reg() {
++    static ggml_backend_reg reg = {
++        /* .api_version = */ GGML_BACKEND_API_VERSION,
++        /* .iface       = */ ggml_backend_vk_reg_i,
++        /* .context     = */ nullptr,
++    };
++    try {
++        ggml_vk_instance_init();
++        return &reg;
++    } catch (const vk::SystemError& e) {
++        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
++        return nullptr;
++    } catch (const std::exception &e) {
++        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what());
++        return nullptr;
++    } catch (...) {
++        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init");
++        return nullptr;
++    }
++}
++
++// Extension availability
++// True if the Khronos validation layer exposes VK_EXT_layer_settings.
++// Always false unless built with GGML_VULKAN_VALIDATE.
++static bool ggml_vk_instance_layer_settings_available() {
++#ifdef GGML_VULKAN_VALIDATE
++    // Check if validation layer provides the extension
++    const std::string layer_name = "VK_LAYER_KHRONOS_validation";
++    for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
++        if (layer_name == layer.layerName.data()) {
++            for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
++                if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
++                    return true;
++                }
++            }
++        }
++    }
++
++    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
++#endif
++    return false;
++}
++// True if VK_KHR_portability_enumeration is present in the enumerated instance
++// extensions. Only checked on Apple platforms (MoltenVK); false elsewhere.
++static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
++#ifdef __APPLE__
++    // Check for portability enumeration extension for MoltenVK support
++    for (const auto& properties : instance_extensions) {
++        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
++            return true;
++        }
++    }
++    std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
++#endif
++    return false;
++
++    UNUSED(instance_extensions);
++}
++
++// Extension availability
++static bool ggml_vk_instance_debug_utils_ext_available(
++    const std::vector<vk::ExtensionProperties> & instance_extensions) {
++    // Check for portability enumeration extension for MoltenVK support
++    for (const auto & properties : instance_extensions) {
++        if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
++            return true;
++        }
++    }
++
++    std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
++    return false;
++
++    UNUSED(instance_extensions);
++}
++
++static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
++    VkPhysicalDeviceFeatures2 device_features2;
++    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
++
++    VkPhysicalDeviceVulkan11Features vk11_features;
++    vk11_features.pNext = nullptr;
++    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
++    device_features2.pNext = &vk11_features;
++
++    vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);
++
++    return vk11_features.storageBuffer16BitAccess;
++}
++
++static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
++    switch (props.vendorID) {
++    case VK_VENDOR_ID_INTEL:
++        // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
++        // while some older hardware (ex. Arc A770) has performance regressions
++        return arch == vk_device_architecture::INTEL_XE2;
++    case VK_VENDOR_ID_AMD:
++        if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
++            // Workaround for AMD proprietary driver reporting support on all GPUs
++            return arch == vk_device_architecture::AMD_RDNA3;
++        }
++        return true;
++    default:
++        return true;
++    }
++}
++
++static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
++    VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
++
++    if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
++        return 0;
++    }
++
++    const uint32_t device_id = props.properties.deviceID;
++
++    switch (device_id) {
++    case 0x56A6:  // A310
++        return 6;
++    case 0x5693:  // A370M
++    case 0x56A5:  // A380
++    case 0x56B1:  // Pro A40/A50
++        return 8;
++    case 0x5697:  // A530M
++        return 12;
++    case 0x5692:  // A550M
++    case 0x56B3:  // Pro A60
++        return 16;
++    case 0x56A2:  // A580
++        return 24;
++    case 0x5691:  // A730M
++    case 0x56A1:  // A750
++        return 28;
++    case 0x56A0:  // A770
++    case 0x5690:  // A770M
++        return 32;
++    case 0xE212:  // Pro B50
++        return 16;
++    case 0xE20C:  // B570
++        return 18;
++    case 0xE20B:  // B580
++        return 20;
++    default:
++        return 0;
++    }
++}
++
++// checks
++
++#ifdef GGML_VULKAN_CHECK_RESULTS
++static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
++    if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
++        return;
++    }
++    for (int j = 0; j < level; j++) {
++        std::cerr << " ";
++    }
++    std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
++
++    done.push_back(tensor);
++
++    for (int i = 0; i < GGML_MAX_SRC; i++) {
++        if (tensor->src[i] != nullptr) {
++            ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
++        }
++    }
++}
++
++static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
++    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
++        return;
++    }
++    i0 = std::max(i0, 5);
++    i1 = std::max(i1, 5);
++    i2 = std::max(i2, 0);
++    i3 = std::max(i3, 0);
++    fprintf(stderr, "         ");
++    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++        fprintf(stderr, "%7d ", idx1);
++    }
++    fprintf(stderr, "\n");
++    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
++        fprintf(stderr, "%7d: ", idx0);
++        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
++            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
++                float val;
++                if (tensor->type == GGML_TYPE_F32) {
++                    val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
++                } else if (tensor->type == GGML_TYPE_F16) {
++                    val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
++                } else if (tensor->type == GGML_TYPE_I32) {
++                    val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
++                } else {
++                    GGML_ABORT("fatal error");
++                }
++                fprintf(stderr, "% 7.2f ", val);
++            } else {
++                fprintf(stderr, "        ");
++            }
++        }
++        fprintf(stderr, "\n");
++    }
++}
++
++static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
++    void * tensor_data = tensor->data;
++
++    const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
++
++    if (is_gpu) {
++        const size_t tensor_size = ggml_nbytes(tensor);
++        tensor_data = malloc(tensor_size);
++
++        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
++
++        vk_buffer buffer_gpu = buf_ctx->dev_buffer;
++        ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
++    }
++
++    std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
++    std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
++    if (tensor->src[0] != nullptr) {
++        std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
++    }
++    if (tensor->src[1] != nullptr) {
++        std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
++    }
++    std::cerr << std::endl << "Result:" << std::endl;
++    ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
++    std::cerr << std::endl;
++    std::vector<const ggml_tensor *> done;
++    ggml_vk_print_graph_origin(tensor, done);
++
++    if (is_gpu) {
++        free(tensor_data);
++    }
++}
++
++void * comp_result;
++size_t comp_size;
++size_t comp_nb[GGML_MAX_DIMS];
++size_t check_counter = 0;
++static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
++    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
++    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
++        return;
++    }
++
++    check_counter++;
++    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
++        return;
++    }
++
++    VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
++
++    struct ggml_init_params iparams = {
++        /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,
++        /*.mem_buffer =*/ NULL,
++        /*.no_alloc   =*/ false,
++    };
++
++    struct ggml_context * ggml_ctx = ggml_init(iparams);
++
++    std::array<ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
++    const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
++
++    std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;
++    std::vector<void *> cloned_mallocs;
++
++    struct ggml_tensor * tensor_clone = nullptr;
++
++    for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {
++        tensor = cgraph->nodes[tensor_idx + f];
++        for (int i = 0; i < GGML_MAX_SRC; i++) {
++            ggml_tensor * srci = tensor->src[i];
++            if (srci == nullptr) {
++                continue;
++            }
++            // If a src tensor has been cloned, use that one
++            auto it = cloned_tensors.find(srci);
++            if (it != cloned_tensors.end()) {
++                src_clone[i] = it->second;
++                continue;
++            }
++            ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
++            size_t srci_size = ggml_nbytes(srci);
++
++            src_clone[i] = srci_clone;
++            void *src_buffer = malloc(srci_size);
++            cloned_mallocs.push_back(src_buffer);
++
++            srci_clone->data = src_buffer;
++            if (ggml_backend_buffer_is_host(srci->buffer)) {
++                memcpy(srci_clone->data, srci->data, srci_size);
++                memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
++            } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
++                ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
++                vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
++                uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
++                if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
++                    for (int i3 = 0; i3 < srci->ne[3]; i3++) {
++                        for (int i2 = 0; i2 < srci->ne[2]; i2++) {
++                            const int idx = i3*srci->ne[2] + i2;
++                            ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
++                        }
++                    }
++
++                    srci_clone->nb[0] = srci->nb[0];
++                    srci_clone->nb[1] = srci->nb[1];
++                    for (int i = 2; i < GGML_MAX_DIMS; i++) {
++                        srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
++                    }
++                } else {
++                    if (offset + srci_size >= buffer_gpu->size) {
++                        srci_size = buffer_gpu->size - offset;
++                    }
++                    ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
++                    memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
++                }
++            } else {
++                GGML_ABORT("fatal error");
++            }
++
++            if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
++                ggml_vk_print_tensor(srci, srci_name[i]);
++            }
++        }
++
++        if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
++            const float * params = (const float *)tensor->op_params;
++            tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
++            if (src_clone[4]) {
++                ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
++            }
++        } else if (tensor->op == GGML_OP_MUL_MAT) {
++            tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
++            tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
++        } else if (tensor->op == GGML_OP_SUB) {
++            tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_MUL) {
++            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_DIV) {
++            tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_CONCAT) {
++            tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
++        } else if (tensor->op == GGML_OP_UPSCALE) {
++            tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
++        } else if (tensor->op == GGML_OP_SCALE) {
++            const float * params = (const float *)tensor->op_params;
++            tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
++        } else if (tensor->op == GGML_OP_ADD1) {
++            tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_ARANGE) {
++            const float start = ggml_get_op_params_f32(tensor, 0);
++            const float stop = ggml_get_op_params_f32(tensor, 1);
++            const float step = ggml_get_op_params_f32(tensor, 2);
++            tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
++        } else if (tensor->op == GGML_OP_FILL) {
++            const float value = ggml_get_op_params_f32(tensor, 0);
++            tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
++        } else if (tensor->op == GGML_OP_SQR) {
++            tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_SQRT) {
++            tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_SIN) {
++            tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_COS) {
++            tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_LOG) {
++            tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_TRI) {
++            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
++        } else if (tensor->op == GGML_OP_DIAG) {
++            tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_CLAMP) {
++            const float * params = (const float *)tensor->op_params;
++            tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
++        } else if (tensor->op == GGML_OP_PAD) {
++            tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
++                                                                tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
++        } else if (tensor->op == GGML_OP_REPEAT) {
++            tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
++        } else if (tensor->op == GGML_OP_REPEAT_BACK) {
++            tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
++        } else if (tensor->op == GGML_OP_ADD) {
++            tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_ACC) {
++            tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
++        } else if (tensor->op == GGML_OP_SET) {
++            tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
++        } else if (tensor->op == GGML_OP_NORM) {
++            tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
++        } else if (tensor->op == GGML_OP_GROUP_NORM) {
++            const float * float_params = (const float *)tensor->op_params;
++            tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
++        } else if (tensor->op == GGML_OP_RMS_NORM) {
++            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
++        } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
++            const float eps = ((float *) tensor->op_params)[0];
++            tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
++        } else if (tensor->op == GGML_OP_SILU_BACK) {
++            tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_L2_NORM) {
++            const float eps = ((float *) tensor->op_params)[0];
++            tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
++        } else if (tensor->op == GGML_OP_SOFT_MAX) {
++            if (tensor->src[1] != nullptr) {
++                const float * params = (const float *)tensor->op_params;
++                tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
++            } else {
++                tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
++            }
++        } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
++            tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
++        } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
++            tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
++        } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
++            const int n_dims      = ((int32_t *) tensor->op_params)[1];
++            const int mode        = ((int32_t *) tensor->op_params)[2];
++            //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
++            const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
++            const float freq_base       = ((float *) tensor->op_params)[5];
++            const float freq_scale      = ((float *) tensor->op_params)[6];
++            const float ext_factor      = ((float *) tensor->op_params)[7];
++            const float attn_factor     = ((float *) tensor->op_params)[8];
++            const float beta_fast       = ((float *) tensor->op_params)[9];
++            const float beta_slow       = ((float *) tensor->op_params)[10];
++            if (mode & GGML_ROPE_TYPE_MROPE) {
++                int32_t *sections = ((int32_t *) tensor->op_params) + 11;
++                if (tensor->op == GGML_OP_ROPE) {
++                    tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
++                } else {
++                    tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
++                }
++            } else {
++                if (tensor->op == GGML_OP_ROPE) {
++                    tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
++                } else {
++                    tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
++                }
++            }
++        } else if (tensor->op == GGML_OP_UNARY) {
++            switch (ggml_get_unary_op(tensor)) {
++            case GGML_UNARY_OP_EXP:
++                tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_SILU:
++                tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_GELU:
++                tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_GELU_ERF:
++                tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_GELU_QUICK:
++                tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_RELU:
++                tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_XIELU:
++                tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
++                ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
++                ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
++                ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
++                ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
++                break;
++            case GGML_UNARY_OP_NEG:
++                tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_TANH:
++                tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_SIGMOID:
++                tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_HARDSIGMOID:
++                tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_HARDSWISH:
++                tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_ABS:
++                tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_SOFTPLUS:
++                tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_STEP:
++                tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_ROUND:
++                tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_CEIL:
++                tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_FLOOR:
++                tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
++                break;
++            case GGML_UNARY_OP_TRUNC:
++                tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
++                break;
++            default:
++                std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
++                GGML_ABORT("fatal error");
++            }
++        } else if (tensor->op == GGML_OP_GLU) {
++            if (src_clone[1] == nullptr) {
++                tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
++            } else {
++                tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
++            }
++            ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
++            ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
++        } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
++            if (tensor->src[1] == nullptr) {
++                tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
++                tensor_clone->type = tensor->type;
++            } else {
++                tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
++            }
++        } else if (tensor->op == GGML_OP_CONT) {
++            tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
++        } else if (tensor->op == GGML_OP_RESHAPE) {
++            tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
++        } else if (tensor->op == GGML_OP_VIEW) {
++            tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
++        } else if (tensor->op == GGML_OP_PERMUTE) {
++            int32_t * params = (int32_t *)tensor->op_params;
++            tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
++        } else if (tensor->op == GGML_OP_TRANSPOSE) {
++            tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_GET_ROWS) {
++            tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_ARGSORT) {
++            tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
++        } else if (tensor->op == GGML_OP_TOP_K) {
++            tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
++        } else if (tensor->op == GGML_OP_SUM) {
++            tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_SUM_ROWS) {
++            tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_CUMSUM) {
++            tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_MEAN) {
++            tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_ARGMAX) {
++            tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
++        } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
++            tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_SOLVE_TRI) {
++            tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
++        } else if (tensor->op == GGML_OP_IM2COL) {
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t s1 = tensor->op_params[1];
++            const int32_t p0 = tensor->op_params[2];
++            const int32_t p1 = tensor->op_params[3];
++            const int32_t d0 = tensor->op_params[4];
++            const int32_t d1 = tensor->op_params[5];
++
++            const bool is_2D = tensor->op_params[6] == 1;
++            tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
++        } else if (tensor->op == GGML_OP_IM2COL_3D) {
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t s1 = tensor->op_params[1];
++            const int32_t s2 = tensor->op_params[2];
++            const int32_t p0 = tensor->op_params[3];
++            const int32_t p1 = tensor->op_params[4];
++            const int32_t p2 = tensor->op_params[5];
++            const int32_t d0 = tensor->op_params[6];
++            const int32_t d1 = tensor->op_params[7];
++            const int32_t d2 = tensor->op_params[8];
++            const int32_t IC = tensor->op_params[9];
++
++            tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
++        } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
++            const int32_t dim = tensor->op_params[0];
++            const int32_t max_period = tensor->op_params[1];
++            tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
++        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t p0 = tensor->op_params[1];
++            const int32_t d0 = tensor->op_params[2];
++            tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
++        } else if (tensor->op == GGML_OP_POOL_2D) {
++            enum ggml_op_pool op = static_cast<enum ggml_op_pool>(tensor->op_params[0]);
++            const int32_t k0 = tensor->op_params[1];
++            const int32_t k1 = tensor->op_params[2];
++            const int32_t s0 = tensor->op_params[3];
++            const int32_t s1 = tensor->op_params[4];
++            const int32_t p0 = tensor->op_params[5];
++            const int32_t p1 = tensor->op_params[6];
++
++            tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
++        } else if (tensor->op == GGML_OP_CONV_2D) {
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t s1 = tensor->op_params[1];
++            const int32_t p0 = tensor->op_params[2];
++            const int32_t p1 = tensor->op_params[3];
++            const int32_t d0 = tensor->op_params[4];
++            const int32_t d1 = tensor->op_params[5];
++            tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
++        } else if (tensor->op == GGML_OP_CONV_2D_DW) {
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t s1 = tensor->op_params[1];
++            const int32_t p0 = tensor->op_params[2];
++            const int32_t p1 = tensor->op_params[3];
++            const int32_t d0 = tensor->op_params[4];
++            const int32_t d1 = tensor->op_params[5];
++            tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
++        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
++            const int32_t s = tensor->op_params[0];
++            tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
++        } else if (tensor->op == GGML_OP_LEAKY_RELU) {
++            const float * op_params = (const float *)tensor->op_params;
++            tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
++        } else if (tensor->op == GGML_OP_RWKV_WKV6) {
++            tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
++            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
++        } else if (tensor->op == GGML_OP_RWKV_WKV7) {
++            tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
++            src_clone[4], src_clone[5], src_clone[6]);
++        } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
++            src_clone[0]->flags = tensor->src[0]->flags;
++            tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
++            src_clone[2], src_clone[3], src_clone[4]);
++        } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
++            src_clone[0]->flags = tensor->src[0]->flags;
++            tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
++            src_clone[2]);
++        } else if (tensor->op == GGML_OP_ADD_ID) {
++            tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
++        } else if (tensor->op == GGML_OP_SSM_SCAN) {
++            tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
++                                         src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
++        } else if (tensor->op == GGML_OP_SSM_CONV) {
++            tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
++        } else if (tensor->op == GGML_OP_ROLL) {
++            const int32_t s0 = tensor->op_params[0];
++            const int32_t s1 = tensor->op_params[1];
++            const int32_t s2 = tensor->op_params[2];
++            const int32_t s3 = tensor->op_params[3];
++            tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);
++        }
++        else {
++            std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
++            GGML_ABORT("fatal error");
++        }
++        cloned_tensors[tensor] = tensor_clone;
++    }
++
++    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
++    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
++
++    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
++
++    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
++        ggml_vk_print_tensor(tensor_clone, "tensor_clone");
++    }
++
++    comp_size = ggml_nbytes(tensor_clone);
++
++    comp_result = malloc(comp_size);
++    memcpy(comp_result, tensor_clone->data, comp_size);
++    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
++
++    for (auto m : cloned_mallocs) {
++        free(m);
++    }
++
++    ggml_free(ggml_ctx);
++
++    VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
++}
++
++static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
++    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
++    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
++        return;
++    }
++
++    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
++        return;
++    }
++
++    VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
++
++    ggml_tensor * src0 = tensor->src[0];
++    ggml_tensor * src1 = tensor->src[1];
++    ggml_tensor * src2 = tensor->src[2];
++    ggml_tensor * src3 = tensor->src[3];
++
++    void * tensor_data = tensor->data;
++
++    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
++        size_t tensor_size = ggml_nbytes(tensor);
++        tensor_data = malloc(tensor_size);
++
++        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
++
++        vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
++        uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
++        if (offset + tensor_size >= buffer_gpu->size) {
++            tensor_size = buffer_gpu->size - offset;
++        }
++
++        ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
++    }
++
++    float first_error_result = -1.0f;
++    float first_error_correct = -1.0f;
++    std::array first_error = { -1, -1, -1, -1 };
++    double avg_err = 0.0;
++    size_t counter = 0;
++
++    for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
++        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
++            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
++                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
++                    const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
++                    float correct = 0.0f;
++                    float result = 0.0f;
++
++                    if (buffer_size_fit) {
++                        if (tensor->type == GGML_TYPE_F32) {
++                            correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
++                            result  = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
++                        } else if (tensor->type == GGML_TYPE_F16) {
++                            correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
++                            result  = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
++                        } else if (tensor->type == GGML_TYPE_BF16) {
++                            correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
++                            result  = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
++                        } else if (tensor->type == GGML_TYPE_I32) {
++                            correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
++                            result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
++                        } else if (tensor->type == GGML_TYPE_I64) {
++                            correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
++                            result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
++                        } else {
++                            std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
++                        }
++                    } else {
++                        std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
++                        GGML_ABORT("fatal error");
++                    }
++
++                    if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
++                        std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
++                        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
++                        if (src0 != nullptr) {
++                            std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
++                        }
++                        if (src1 != nullptr) {
++                            std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
++                        }
++                        if (src2 != nullptr) {
++                            std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
++                        }
++                        if (src3 != nullptr) {
++                            std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
++                        }
++                        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
++                        std::cerr << std::endl << "Result:" << std::endl;
++                        ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
++                        std::cerr << std::endl << "Correct:" << std::endl;
++                        ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
++                        std::cerr << std::endl;
++                        std::vector done;
++                        ggml_vk_print_graph_origin(tensor, done);
++                        GGML_ABORT("fatal error");
++                    }
++                    const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
++                    if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
++                        first_error[0] = i0;
++                        first_error[1] = i1;
++                        first_error[2] = i2;
++                        first_error[3] = i3;
++                        first_error_result = result;
++                        first_error_correct = correct;
++                    }
++
++                    // Special case, value is infinite, avoid NaN result in avg_err
++                    // NaN also appears in results, if both are nan error is 0
++                    if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
++                        avg_err += std::fabs(correct - result) / denom;
++                    }
++                    counter++;
++                }
++            }
++        }
++    }
++
++    avg_err /= counter;
++
++    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
++        std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
++        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
++        if (src0 != nullptr) {
++            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
++        }
++        if (src1 != nullptr) {
++            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
++        }
++        if (src2 != nullptr) {
++            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
++        }
++        if (src3 != nullptr) {
++            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
++        }
++        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
++        std::cerr << std::endl << "Result:" << std::endl;
++        ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
++        std::cerr << std::endl << "Correct:" << std::endl;
++        ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
++        std::cerr << std::endl;
++        std::vector done;
++        ggml_vk_print_graph_origin(tensor, done);
++    }
++
++    if (avg_err > 0.01 || std::isnan(avg_err)) {
++        std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
++        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
++        if (src0 != nullptr) {
++            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
++        }
++        if (src1 != nullptr) {
++            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
++        }
++        if (src2 != nullptr) {
++            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
++        }
++        if (src3 != nullptr) {
++            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
++        }
++        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
++        std::cerr << std::endl << "Result:" << std::endl;
++        ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
++        std::cerr << std::endl << "Correct:" << std::endl;
++        ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
++        std::cerr << std::endl;
++        std::vector done;
++        ggml_vk_print_graph_origin(tensor, done);
++        GGML_ABORT("fatal error");
++    } else {
++        std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
++    }
++
++    free(comp_result);
++    comp_result = nullptr;
++    comp_size = 0;
++
++    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
++        free(tensor_data);
++    }
++
++    VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
++}
++#endif
++
++GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)
+diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+index 1c00d3cb2..5e6859506 100644
+--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
++++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+@@ -2035,8 +2035,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
+     }
+ }
+ 
+-static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
++static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
+     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
++    GGML_UNUSED(batch_size);
+ 
+     ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
+     webgpu_context                ctx         = backend_ctx->webgpu_ctx;
+diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+index 9b6938abf..9c5f3ae84 100644
+--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp
++++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+@@ -412,7 +412,8 @@ static void ggml_backend_zdnn_free(ggml_backend_t backend) {
+     free(backend);
+ }
+ 
+-static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
++    GGML_UNUSED(batch_size);
+     return ggml_zdnn_graph_compute(backend, cgraph);
+ }
+ 
+diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+index c87603040..88b3bd628 100644
+--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp
++++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+@@ -204,9 +204,11 @@ static void ggml_backend_zendnn_free(ggml_backend_t backend) {
+     delete backend;
+ }
+ 
+-static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
++static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+     ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
+ 
++    GGML_UNUSED(batch_size);
++
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+         struct ggml_tensor * node = cgraph->nodes[i];
+ 
diff --git a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch b/llama/patches/0018-fix-mtmd-audio.cpp-build-on-windows.patch
similarity index 94%
rename from llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch
rename to llama/patches/0018-fix-mtmd-audio.cpp-build-on-windows.patch
index 2c4e305045d..b9b541d53dd 100644
--- a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch
+++ b/llama/patches/0018-fix-mtmd-audio.cpp-build-on-windows.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] fix mtmd-audio.cpp build on windows
  1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
-index f68829a61..2024d3d37 100644
+index e8eef035f..a208c7789 100644
 --- a/tools/mtmd/mtmd-audio.cpp
 +++ b/tools/mtmd/mtmd-audio.cpp
 @@ -1,6 +1,6 @@
diff --git a/llama/patches/0018-ggml-Add-batch-size-hint.patch b/llama/patches/0018-ggml-Add-batch-size-hint.patch
deleted file mode 100644
index 5b66ee3621a..00000000000
--- a/llama/patches/0018-ggml-Add-batch-size-hint.patch
+++ /dev/null
@@ -1,300 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Jesse Gross 
-Date: Tue, 28 Oct 2025 17:36:54 -0700
-Subject: [PATCH] ggml: Add batch size hint
-
-Some operations use heuristics to determine the batch size, which
-affects offloading decisions. However, these are not always
-accurate when looking at single operations. This provides an
-explicit signal on the batch size from higher layers to ensure
-consistent performance.
----
- ggml/include/ggml-backend.h          |  5 ++-
- ggml/src/ggml-backend-impl.h         |  4 +--
- ggml/src/ggml-backend.cpp            | 19 +++++++----
- ggml/src/ggml-blas/ggml-blas.cpp     |  3 +-
- ggml/src/ggml-cpu/ggml-cpu.cpp       |  4 ++-
- ggml/src/ggml-cuda/ggml-cuda.cu      | 48 +++++++++++++++++-----------
- ggml/src/ggml-metal/ggml-metal.cpp   |  4 ++-
- ggml/src/ggml-vulkan/ggml-vulkan.cpp |  3 +-
- 8 files changed, 58 insertions(+), 32 deletions(-)
-
-diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index 03557bb31..93c95602d 100644
---- a/ggml/include/ggml-backend.h
-+++ b/ggml/include/ggml-backend.h
-@@ -98,7 +98,7 @@ extern "C" {
- 
-     GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-     GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
--    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
-+    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
- 
-     // NOTE: will be removed, use device version instead
-     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
-@@ -307,6 +307,9 @@ extern "C" {
-     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
-     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
- 
-+    // Provide a hint on the batch size to optimize processing (uses heuristics if unset)
-+    GGML_API void                 ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size);
-+
-     // Initialize backend buffers from a measure graph
-     GGML_API void                 ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);
-     GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
-diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
-index 6792ba986..0f5b03cef 100644
---- a/ggml/src/ggml-backend-impl.h
-+++ b/ggml/src/ggml-backend-impl.h
-@@ -106,8 +106,8 @@ extern "C" {
-         // compute the graph with the plan
-         enum ggml_status          (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
- 
--        // compute graph (always async if supported by the backend)
--        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-+        // compute graph (always async if supported by the backend). batch_size may be -1 if unknown
-+        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
- 
-         // (optional) event synchronization
-         // record an event on this stream
-diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index 1459d16dd..498186a7c 100644
---- a/ggml/src/ggml-backend.cpp
-+++ b/ggml/src/ggml-backend.cpp
-@@ -353,14 +353,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba
- }
- 
- enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
--    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
-+    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph, -1);
-     ggml_backend_synchronize(backend);
-     return err;
- }
- 
--enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-+enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
-     GGML_ASSERT(backend);
--    return backend->iface.graph_compute(backend, cgraph);
-+    return backend->iface.graph_compute(backend, cgraph, batch_size);
- }
- 
- bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-@@ -727,6 +727,8 @@ struct ggml_backend_sched {
- 
-     bool op_offload;
- 
-+    int batch_size; // a hint on the batch size to optimize processing, -1 to use heuristics
-+
-     int debug;
- 
-     // used for debugging graph reallocations [GGML_SCHED_DEBUG_REALLOC]
-@@ -825,7 +827,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
-         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
-             // check if a backend with higher prio wants to offload the op
--            if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
-+            if (sched->op_offload && (sched->batch_size < 0 || sched->batch_size >= 32) && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
-                 for (int b = 0; b < src_backend_id; b++) {
-                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
-                         SET_CAUSE(tensor, "1.off");
-@@ -1577,7 +1579,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
-         }
- 
-         if (!sched->callback_eval) {
--            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
-+            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph, sched->batch_size);
-             if (ec != GGML_STATUS_SUCCESS) {
-                 return ec;
-             }
-@@ -1599,7 +1601,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
- 
-                 struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
- 
--                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
-+                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv, sched->batch_size);
-                 if (ec != GGML_STATUS_SUCCESS) {
-                     return ec;
-                 }
-@@ -1689,6 +1691,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
- 
-     sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
-     sched->op_offload = op_offload;
-+    sched->batch_size = -1;
- 
-     ggml_backend_sched_reset(sched);
- 
-@@ -1720,6 +1723,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
-     free(sched);
- }
- 
-+void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) {
-+    sched->batch_size = batch_size;
-+}
-+
- void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
-     GGML_ASSERT(sched);
-     // reset state for the next run
-diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
-index 5b888cdd8..88d088952 100644
---- a/ggml/src/ggml-blas/ggml-blas.cpp
-+++ b/ggml/src/ggml-blas/ggml-blas.cpp
-@@ -224,7 +224,7 @@ static void ggml_backend_blas_free(ggml_backend_t backend) {
-     delete backend;
- }
- 
--static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-+static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
-     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
- 
-     for (int i = 0; i < cgraph->n_nodes; i++) {
-@@ -254,6 +254,7 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
-     return GGML_STATUS_SUCCESS;
- 
-     GGML_UNUSED(backend);
-+    GGML_UNUSED(batch_size);
- }
- 
- static struct ggml_backend_i blas_backend_i = {
-diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
-index f4713a421..92ba577a5 100644
---- a/ggml/src/ggml-cpu/ggml-cpu.cpp
-+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
-@@ -164,7 +164,7 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe
-     GGML_UNUSED(backend);
- }
- 
--static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-+static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
-     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
- 
-     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
-@@ -184,6 +184,8 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
-     cplan.abort_callback_data = cpu_ctx->abort_callback_data;
- 
-     return ggml_graph_compute(cgraph, &cplan);
-+
-+    GGML_UNUSED(batch_size);
- }
- 
- static const struct ggml_backend_i ggml_backend_cpu_i = {
-diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index c9d3a2b03..25548629d 100644
---- a/ggml/src/ggml-cuda/ggml-cuda.cu
-+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2901,7 +2901,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
- 
- #ifdef USE_CUDA_GRAPH
- static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
--    bool use_cuda_graph) {
-+    int batch_size, bool use_cuda_graph) {
- 
-     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
- 
-@@ -2934,24 +2934,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
- #endif
-         }
- 
--        if (node->op == GGML_OP_ADD &&
--            node->src[1] && node->src[1]->ne[1] > 1 &&
--            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
--            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
--            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
--            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
--            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
--            strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
--            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
--            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
--            // by means of matching node names. See
--            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
--            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
--            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
--            use_cuda_graph = false;
-+        // If we have an explicit batch size hint then we don't need to use the tensor name heuristics
-+        if (batch_size >= 0) {
-+            if (batch_size > 1) {
-+                use_cuda_graph = false;
- #ifndef NDEBUG
--            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%d]\n", __func__, batch_size);
- #endif
-+            }
-+        } else {
-+            if (node->op == GGML_OP_ADD &&
-+                node->src[1] && node->src[1]->ne[1] > 1 &&
-+                (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
-+                (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
-+                strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
-+                strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-+                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
-+                strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
-+                strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
-+                // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
-+                // by means of matching node names. See
-+                // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
-+                // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
-+                // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
-+                use_cuda_graph = false;
-+#ifndef NDEBUG
-+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-+#endif
-+            }
-         }
- 
-         if (!use_cuda_graph) {
-@@ -3742,7 +3752,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
-     }
- }
- 
--static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
-     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
- 
-     ggml_cuda_set_device(cuda_ctx->device);
-@@ -3780,7 +3790,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
-     if (use_cuda_graph) {
-         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
- 
--        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
-+        use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);
- 
-         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-         if (use_cuda_graph && cuda_graph_update_required) {
-diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
-index 8fc1c2fb5..ba95b4acc 100644
---- a/ggml/src/ggml-metal/ggml-metal.cpp
-+++ b/ggml/src/ggml-metal/ggml-metal.cpp
-@@ -419,10 +419,12 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml
-     GGML_UNUSED(dst);
- }
- 
--static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-+static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
-     ggml_metal_t ctx = (ggml_metal_t)backend->context;
- 
-     return ggml_metal_graph_compute(ctx, cgraph);
-+
-+    GGML_UNUSED(batch_size);
- }
- 
- static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
-diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index 120191ca0..5349bce24 100644
---- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-@@ -13099,7 +13099,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
-     return num_adds;
- }
- 
--static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-+static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
-     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
-     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
- 
-@@ -13334,6 +13334,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
-     return GGML_STATUS_SUCCESS;
- 
-     UNUSED(backend);
-+    UNUSED(batch_size);
- }
- 
- // Sort the graph for improved parallelism.
diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0019-ggml-No-alloc-mode.patch
similarity index 86%
rename from llama/patches/0020-ggml-No-alloc-mode.patch
rename to llama/patches/0019-ggml-No-alloc-mode.patch
index 19f5f7e73c2..86bc0c3afd2 100644
--- a/llama/patches/0020-ggml-No-alloc-mode.patch
+++ b/llama/patches/0019-ggml-No-alloc-mode.patch
@@ -3,23 +3,21 @@ From: Jesse Gross 
 Date: Wed, 23 Jul 2025 11:58:49 -0700
 Subject: [PATCH] ggml: No-alloc mode
 
-Callers can set a scheduler to be no-alloc, meaning that
-it does not allocate memory for tensors or operations. This can
-be used for calculating memory requirements. Tensors and graphs
-must be recreated with no-alloc set to false before loading data.
+Adds infrastructure for scheduler no-alloc mode that enables
+fast memory sizing calculations without actual allocations.
 ---
  ggml/include/ggml-backend.h     |   1 +
  ggml/src/ggml-backend-impl.h    |  16 +++
- ggml/src/ggml-backend.cpp       |  75 ++++++++++-
- ggml/src/ggml-cuda/common.cuh   |  62 ++++++++-
- ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------
- 5 files changed, 333 insertions(+), 45 deletions(-)
+ ggml/src/ggml-backend.cpp       |  75 +++++++++++-
+ ggml/src/ggml-cuda/common.cuh   |  62 +++++++++-
+ ggml/src/ggml-cuda/ggml-cuda.cu | 211 ++++++++++++++++++++++++++------
+ 5 files changed, 320 insertions(+), 45 deletions(-)
 
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index 93c95602d..dbbb61d9c 100644
+index 97f630faa..cc4f0e5af 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
-@@ -305,6 +305,7 @@ extern "C" {
+@@ -306,6 +306,7 @@ extern "C" {
  
      // Initialize a backend scheduler, backends with low index are given priority over backends with high index
      GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
@@ -28,7 +26,7 @@ index 93c95602d..dbbb61d9c 100644
  
      // Provide a hint on the batch size to optimize processing (uses heuristics if unset)
 diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
-index 0f5b03cef..7bdf9d81f 100644
+index a833756f9..93eb8e511 100644
 --- a/ggml/src/ggml-backend-impl.h
 +++ b/ggml/src/ggml-backend-impl.h
 @@ -26,12 +26,17 @@ extern "C" {
@@ -75,7 +73,7 @@ index 0f5b03cef..7bdf9d81f 100644
  
      struct ggml_backend {
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index 498186a7c..7746e8b92 100644
+index 761bd12df..bcc42c5c5 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
 @@ -36,11 +36,25 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
@@ -128,7 +126,7 @@ index 498186a7c..7746e8b92 100644
      // FIXME JG: a multi_buffer has a non-zero size, according to the above comment get_base is not optional,
      //     I don't know whether the above comment is correct
      if (!buffer->iface.get_base) {
-@@ -736,6 +757,12 @@ struct ggml_backend_sched {
+@@ -738,6 +759,12 @@ struct ggml_backend_sched {
      int debug_realloc;
      int debug_graph_size;
      int debug_prev_graph_size;
@@ -141,7 +139,7 @@ index 498186a7c..7746e8b92 100644
  };
  
  #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
-@@ -1635,6 +1662,17 @@ ggml_backend_sched_t ggml_backend_sched_new(
+@@ -1637,6 +1664,17 @@ ggml_backend_sched_t ggml_backend_sched_new(
          size_t graph_size,
          bool parallel,
          bool op_offload) {
@@ -159,7 +157,7 @@ index 498186a7c..7746e8b92 100644
      GGML_ASSERT(n_backends > 0);
      GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
      GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
-@@ -1687,11 +1725,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
+@@ -1689,11 +1727,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
                  sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
              }
          }
@@ -174,7 +172,7 @@ index 498186a7c..7746e8b92 100644
  
      ggml_backend_sched_reset(sched);
  
-@@ -1706,6 +1747,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+@@ -1712,6 +1753,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
          for (int c = 0; c < sched->n_copies; c++) {
              ggml_backend_event_free(sched->events[b][c]);
          }
@@ -185,7 +183,7 @@ index 498186a7c..7746e8b92 100644
      }
      ggml_gallocr_free(sched->galloc);
      ggml_free(sched->ctx);
-@@ -1765,6 +1810,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
+@@ -1767,6 +1812,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
          return false;
      }
  
@@ -210,7 +208,7 @@ index 498186a7c..7746e8b92 100644
      ggml_backend_sched_reset(sched);
  
      return true;
-@@ -1870,7 +1933,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched,
+@@ -1872,7 +1935,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched,
      int backend_index = ggml_backend_sched_backend_id(sched, backend);
      GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
  
@@ -226,7 +224,7 @@ index 498186a7c..7746e8b92 100644
  
  void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
 diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
-index 9fcb2f9fd..e800ee8f6 100644
+index 36d8a3aaa..321357713 100644
 --- a/ggml/src/ggml-cuda/common.cuh
 +++ b/ggml/src/ggml-cuda/common.cuh
 @@ -37,6 +37,41 @@
@@ -271,7 +269,7 @@ index 9fcb2f9fd..e800ee8f6 100644
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
  
-@@ -941,6 +976,9 @@ struct ggml_cuda_pool {
+@@ -1062,6 +1097,9 @@ struct ggml_cuda_pool {
  
      virtual void * alloc(size_t size, size_t * actual_size) = 0;
      virtual void free(void * ptr, size_t size) = 0;
@@ -281,7 +279,7 @@ index 9fcb2f9fd..e800ee8f6 100644
  };
  
  template
-@@ -1232,11 +1270,15 @@ struct ggml_backend_cuda_context {
+@@ -1399,11 +1437,15 @@ struct ggml_backend_cuda_context {
      // pool
      std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
  
@@ -299,7 +297,7 @@ index 9fcb2f9fd..e800ee8f6 100644
          }
          return *pools[device][curr_stream_no];
      }
-@@ -1244,6 +1286,22 @@ struct ggml_backend_cuda_context {
+@@ -1411,6 +1453,22 @@ struct ggml_backend_cuda_context {
      ggml_cuda_pool & pool() {
          return pool(device);
      }
@@ -323,10 +321,19 @@ index 9fcb2f9fd..e800ee8f6 100644
  
  struct ggml_cuda_mm_fusion_args_host {
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 25548629d..eeaae3fe4 100644
+index 0ab859d3c..d3dacc270 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -365,6 +365,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
+@@ -85,6 +85,8 @@
+ 
+ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+ 
++bool reserving_graph = false;
++
+ [[noreturn]]
+ void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
+     int id = -1; // in case cudaGetDevice fails
+@@ -371,6 +373,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
  
  // #define DEBUG_CUDA_MALLOC
  
@@ -335,7 +342,7 @@ index 25548629d..eeaae3fe4 100644
  // buffer pool for cuda (legacy)
  struct ggml_cuda_pool_leg : public ggml_cuda_pool {
      static const int MAX_BUFFERS = 256;
-@@ -377,9 +379,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+@@ -383,17 +387,25 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
  
      ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
      size_t pool_size = 0;
@@ -349,8 +356,11 @@ index 25548629d..eeaae3fe4 100644
 +        allocate(alloc) {
      }
  
++    bool alloc_memory() override { return allocate; }
++    size_t alloc_size() override { return pool_size + last_alloc; }
++
      ~ggml_cuda_pool_leg() {
-@@ -387,7 +392,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+         ggml_cuda_set_device(device);
          for (int i = 0; i < MAX_BUFFERS; ++i) {
              ggml_cuda_buffer & b = buffer_pool[i];
              if (b.ptr != nullptr) {
@@ -361,7 +371,7 @@ index 25548629d..eeaae3fe4 100644
                  pool_size -= b.size;
              }
          }
-@@ -435,8 +442,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+@@ -441,8 +453,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
          void * ptr;
          size_t look_ahead_size = (size_t) (1.05 * size);
          look_ahead_size = 256 * ((look_ahead_size + 255)/256);
@@ -379,7 +389,7 @@ index 25548629d..eeaae3fe4 100644
          *actual_size = look_ahead_size;
          pool_size += look_ahead_size;
  #ifdef DEBUG_CUDA_MALLOC
-@@ -456,10 +470,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+@@ -462,8 +481,10 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
              }
          }
          GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
@@ -391,18 +401,8 @@ index 25548629d..eeaae3fe4 100644
 +        }
          pool_size -= size;
      }
-+
-+    bool alloc_memory() override {
-+        return allocate;
-+    }
-+
-+    size_t alloc_size() override {
-+        return pool_size + last_alloc;
-+    }
  };
- 
- // pool with virtual memory
-@@ -471,18 +495,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
+@@ -477,18 +498,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
      CUdeviceptr pool_addr = 0;
      size_t pool_used = 0;
      size_t pool_size = 0;
@@ -424,13 +424,16 @@ index 25548629d..eeaae3fe4 100644
 +        }
      }
  
++    bool alloc_memory() override { return allocate; }
++    size_t alloc_size() override { return pool_size + last_alloc; }
++
      ~ggml_cuda_pool_vmm() {
 -        if (pool_addr != 0) {
 +        if (pool_addr != 0 && allocate) {
  #if defined(GGML_USE_HIP)
              // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
              for (std::pair & mapping : mappings) {
-@@ -509,35 +539,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
+@@ -515,35 +545,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
  
              GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
  
@@ -478,7 +481,13 @@ index 25548629d..eeaae3fe4 100644
 +                    CU_CHECK(cuMemRelease(handle));
 +                    throw std::bad_alloc();
 +                }
-+
+ 
+-            // set access
+-            CUmemAccessDesc access = {};
+-            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+-            access.location.id = device;
+-            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+-            CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
 +                // the memory allocation handle is no longer needed after mapping
 +                CU_CHECK(cuMemRelease(handle));
 +
@@ -492,13 +501,7 @@ index 25548629d..eeaae3fe4 100644
 +                    last_alloc = reserve_size;
 +                    throw std::bad_alloc();
 +                }
- 
--            // set access
--            CUmemAccessDesc access = {};
--            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
--            access.location.id = device;
--            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
--            CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
++
 +    #if defined(GGML_USE_HIP)
 +                mappings.push_back({start_ptr, reserve_size});
 +    #endif
@@ -506,26 +509,13 @@ index 25548629d..eeaae3fe4 100644
  
              // add to the pool
              pool_size += reserve_size;
-@@ -570,17 +614,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
-         // all deallocations must be in reverse order of the allocations
-         GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
-     }
-+
-+    bool alloc_memory() override {
-+        return allocate;
-+    }
-+
-+    size_t alloc_size() override {
-+        return pool_size + last_alloc;
-+    }
-+
- };
+@@ -580,13 +624,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
  #endif // defined(GGML_USE_VMM)
  
  std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int                  device,
 -                                                                               [[maybe_unused]] int stream_no) {
 +                                                                               [[maybe_unused]] int stream_no,
-+                                                                               bool alloc) {
++                                                                               bool                 alloc) {
  #if defined(GGML_USE_VMM)
      if (ggml_cuda_info().devices[device].vmm) {
 -        return std::unique_ptr(new ggml_cuda_pool_vmm(device));
@@ -537,7 +527,7 @@ index 25548629d..eeaae3fe4 100644
  }
  
  // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
-@@ -764,11 +818,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
+@@ -770,11 +815,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
  }
  
  static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -559,7 +549,7 @@ index 25548629d..eeaae3fe4 100644
  static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
      size_t size = ggml_nbytes(tensor);
      int64_t ne0 = tensor->ne[0];
-@@ -792,6 +855,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
+@@ -798,6 +852,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
      /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
      /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
      /* .is_host          = */ NULL,
@@ -567,15 +557,7 @@ index 25548629d..eeaae3fe4 100644
  };
  
  ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
-@@ -3274,6 +3338,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
- 
- static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-     bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
-+
-     // flag used to determine whether it is an integrated_gpu
-     const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
- 
-@@ -3410,6 +3475,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
+@@ -3567,6 +3622,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                      continue;
                  }
  
@@ -583,32 +565,30 @@ index 25548629d..eeaae3fe4 100644
 +                if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
 +                    continue;
 +                }
- 
++
                  // start of fusion operations
                  static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-@@ -3754,6 +3823,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
+                 if (!disable_fusion) {
+@@ -3972,6 +4032,7 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co
  
  static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
-     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 +    cuda_ctx->pool_set_alloc(true);
  
-     ggml_cuda_set_device(cuda_ctx->device);
+     GGML_UNUSED(batch_size);
  
-@@ -3829,6 +3899,77 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
+@@ -4031,6 +4092,73 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
      return GGML_STATUS_SUCCESS;
  }
  
-+// This is used to skip operations that are not graph safe during the reservation process.
-+bool reserving_graph = false;
-+
 +static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) {
 +    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 +    cuda_ctx->pool_set_alloc(alloc);
 +
++    const void * graph_key = nullptr;
 +    #ifdef USE_CUDA_GRAPH
-+    if (cuda_ctx->cuda_graph == nullptr) {
-+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-+    }
++    graph_key = ggml_cuda_graph_get_key(cgraph);
++    // cuda_ctx->cuda_graph(graph_key) will auto-create the graph if needed
 +    #endif
 +
 +    ggml_cuda_set_device(cuda_ctx->device);
@@ -630,9 +610,8 @@ index 25548629d..eeaae3fe4 100644
 +    try {
 +        bool use_cuda_graph = false;
 +        bool cuda_graph_update_required = false;
-+        bool graph_evaluated_or_captured = false;
 +
-+        evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
++        ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
 +    } catch (const std::exception &e) {
 +        result = GGML_STATUS_FAILED;
 +    }
@@ -672,7 +651,7 @@ index 25548629d..eeaae3fe4 100644
  static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
      ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
  
-@@ -4097,6 +4238,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
+@@ -4315,6 +4443,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
      /* .event_record            = */ ggml_backend_cuda_event_record,
      /* .event_wait              = */ ggml_backend_cuda_event_wait,
      /* .graph_optimize          = */ ggml_backend_cuda_graph_optimize,
diff --git a/llama/patches/0021-decode-disable-output_all.patch b/llama/patches/0020-decode-disable-output_all.patch
similarity index 62%
rename from llama/patches/0021-decode-disable-output_all.patch
rename to llama/patches/0020-decode-disable-output_all.patch
index 20001bd978f..af5f71d9e7b 100644
--- a/llama/patches/0021-decode-disable-output_all.patch
+++ b/llama/patches/0020-decode-disable-output_all.patch
@@ -8,16 +8,16 @@ Subject: [PATCH] decode: disable output_all
  1 file changed, 1 insertion(+), 2 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 8786d4ee3..9e6998272 100644
+index 98d055d34..964bb3220 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -1051,8 +1051,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
+@@ -1437,8 +1437,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
      const int64_t n_vocab = vocab.n_tokens();
      const int64_t n_embd  = hparams.n_embd_inp();
  
 -    // when computing embeddings, all tokens are output
--    const bool output_all = cparams.embeddings;
+-    const bool output_all   = cparams.embeddings;
 +    const bool output_all = false;
+     const bool has_samplers = !sampling.samplers.empty();
  
-     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
-         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+     const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
diff --git a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0021-ggml-Enable-resetting-backend-devices.patch
similarity index 75%
rename from llama/patches/0022-ggml-Enable-resetting-backend-devices.patch
rename to llama/patches/0021-ggml-Enable-resetting-backend-devices.patch
index 3197f94e8c8..adffcb7dbd3 100644
--- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch
+++ b/llama/patches/0021-ggml-Enable-resetting-backend-devices.patch
@@ -3,23 +3,22 @@ From: Jesse Gross 
 Date: Wed, 27 Aug 2025 14:39:48 -0700
 Subject: [PATCH] ggml: Enable resetting backend devices
 
-Touching a CUDA device causes the allocation of a primary context
-with CUDA data structures (~300 MB of VRAM). If a device is
-unused then it can be reset to free these data structures.
+Allows resetting CUDA devices to free primary context allocations
+(~300 MB of VRAM per device) when a device is unused.
 ---
  ggml/include/ggml-backend.h      |  1 +
  ggml/src/ggml-backend-impl.h     |  4 ++++
  ggml/src/ggml-backend.cpp        |  8 ++++++++
- ggml/src/ggml-cuda/ggml-cuda.cu  | 16 +++++++++++++++-
+ ggml/src/ggml-cuda/ggml-cuda.cu  | 11 +++++++++++
  ggml/src/ggml-cuda/vendors/hip.h |  1 +
  src/llama.cpp                    |  4 +++-
- 6 files changed, 32 insertions(+), 2 deletions(-)
+ 6 files changed, 28 insertions(+), 1 deletion(-)
 
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index dbbb61d9c..92ca32a4b 100644
+index cc4f0e5af..006867064 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
-@@ -178,6 +178,7 @@ extern "C" {
+@@ -179,6 +179,7 @@ extern "C" {
      GGML_API void                          ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
      GGML_API ggml_backend_reg_t            ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
      GGML_API ggml_backend_t                ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
@@ -28,7 +27,7 @@ index dbbb61d9c..92ca32a4b 100644
      GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
      GGML_API ggml_backend_buffer_t         ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
 diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
-index 7bdf9d81f..21b35ac5c 100644
+index 93eb8e511..c815c2eed 100644
 --- a/ggml/src/ggml-backend-impl.h
 +++ b/ggml/src/ggml-backend-impl.h
 @@ -195,6 +195,10 @@ extern "C" {
@@ -43,10 +42,10 @@ index 7bdf9d81f..21b35ac5c 100644
  
      struct ggml_backend_device {
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index 7746e8b92..189e97170 100644
+index bcc42c5c5..4d8abe025 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
-@@ -532,6 +532,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
+@@ -534,6 +534,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
      return device->iface.init_backend(device, params);
  }
  
@@ -62,10 +61,10 @@ index 7746e8b92..189e97170 100644
      GGML_ASSERT(device);
      return device->iface.get_buffer_type(device);
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index eeaae3fe4..6852d2e20 100644
+index d3dacc270..99fecd81e 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -113,6 +113,11 @@ int ggml_cuda_get_device() {
+@@ -118,6 +118,11 @@ int ggml_cuda_get_device() {
      return id;
  }
  
@@ -77,19 +76,7 @@ index eeaae3fe4..6852d2e20 100644
  static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
      ggml_cuda_set_device(device);
      cudaError_t err;
-@@ -4448,7 +4453,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
-     props->id          = ggml_backend_cuda_device_get_id(dev);
-     props->type        = ggml_backend_cuda_device_get_type(dev);
-     props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
--    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
-+
-+    // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
-+    // If you need the memory data, call ggml_backend_dev_memory() explicitly.
-+    props->memory_total = props->memory_free = 0;
- 
-     bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
- #ifdef GGML_CUDA_NO_PEER_COPY
-@@ -4908,6 +4916,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
+@@ -5119,6 +5124,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
      CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
  }
  
@@ -101,7 +88,7 @@ index eeaae3fe4..6852d2e20 100644
  static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
      /* .get_name                = */ ggml_backend_cuda_device_get_name,
      /* .get_description         = */ ggml_backend_cuda_device_get_description,
-@@ -4924,6 +4937,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
+@@ -5135,6 +5145,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
      /* .event_new               = */ ggml_backend_cuda_device_event_new,
      /* .event_free              = */ ggml_backend_cuda_device_event_free,
      /* .event_synchronize       = */ ggml_backend_cuda_device_event_synchronize,
@@ -110,22 +97,22 @@ index eeaae3fe4..6852d2e20 100644
  
  // backend reg
 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
-index 951a88d56..4e162258d 100644
+index 5cc1b5431..5c172bf5d 100644
 --- a/ggml/src/ggml-cuda/vendors/hip.h
 +++ b/ggml/src/ggml-cuda/vendors/hip.h
-@@ -49,6 +49,7 @@
- #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+@@ -51,6 +51,7 @@
  #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+ #define cudaDeviceGetAttribute hipDeviceGetAttribute
  #define cudaDeviceProp hipDeviceProp_t
 +#define cudaDeviceReset hipDeviceReset
  #define cudaDeviceSynchronize hipDeviceSynchronize
  #define cudaError_t hipError_t
  #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
 diff --git a/src/llama.cpp b/src/llama.cpp
-index f69964b6d..759152b76 100644
+index 6da90d6f1..c5aec0816 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -921,10 +921,12 @@ static struct llama_model * llama_model_load_from_file_impl(
+@@ -997,10 +997,12 @@ static struct llama_model * llama_model_load_from_file_impl(
      for (auto * dev : model->devices) {
          ggml_backend_dev_props props;
          ggml_backend_dev_get_props(dev, &props);
diff --git a/llama/patches/0023-harden-uncaught-exception-registration.patch b/llama/patches/0022-harden-uncaught-exception-registration.patch
similarity index 100%
rename from llama/patches/0023-harden-uncaught-exception-registration.patch
rename to llama/patches/0022-harden-uncaught-exception-registration.patch
diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0023-ollama-GPU-discovery-enhancements.patch
similarity index 87%
rename from llama/patches/0024-GPU-discovery-enhancements.patch
rename to llama/patches/0023-ollama-GPU-discovery-enhancements.patch
index 6e4ef239477..f16561f5e27 100644
--- a/llama/patches/0024-GPU-discovery-enhancements.patch
+++ b/llama/patches/0023-ollama-GPU-discovery-enhancements.patch
@@ -1,37 +1,30 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: Daniel Hiltgen 
 Date: Tue, 26 Aug 2025 12:48:29 -0700
-Subject: [PATCH] GPU discovery enhancements
+Subject: [PATCH] ollama: GPU discovery enhancements
 
-Expose more information about the devices through backend props, and leverage
-management libraries for more accurate VRAM usage reporting if available.
-
-vulkan: get GPU ID (ollama v0.11.5)
-
-Signed-off-by: Xiaodong Ye 
-
-Vulkan PCI and Memory
-
-fix vulkan PCI ID and ID handling
+Add NVML and ADLX memory reporting for accurate VRAM metrics.
+Add new device properties: compute version, driver version, integrated flag, library name.
+Update CUDA, Metal, and Vulkan backends with enhanced device info.
 ---
  ggml/include/ggml-backend.h          |   6 +
  ggml/src/CMakeLists.txt              |   2 +
- ggml/src/ggml-cuda/ggml-cuda.cu      |  65 ++++
+ ggml/src/ggml-cuda/ggml-cuda.cu      |  65 +++-
  ggml/src/ggml-cuda/vendors/hip.h     |   3 +
  ggml/src/ggml-impl.h                 |   8 +
  ggml/src/ggml-metal/ggml-metal.cpp   |   2 +
- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 122 +++++-
  ggml/src/mem_hip.cpp                 | 558 +++++++++++++++++++++++++++
  ggml/src/mem_nvml.cpp                | 209 ++++++++++
- 9 files changed, 1005 insertions(+), 17 deletions(-)
+ 9 files changed, 968 insertions(+), 7 deletions(-)
  create mode 100644 ggml/src/mem_hip.cpp
  create mode 100644 ggml/src/mem_nvml.cpp
 
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index 92ca32a4b..6ad583f09 100644
+index 006867064..21c46f4fc 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
-@@ -169,6 +169,12 @@ extern "C" {
+@@ -170,6 +170,12 @@ extern "C" {
          const char * device_id;
          // device capabilities
          struct ggml_backend_dev_caps caps;
@@ -45,7 +38,7 @@ index 92ca32a4b..6ad583f09 100644
  
      GGML_API const char *                  ggml_backend_dev_name(ggml_backend_dev_t device);
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index d55aed348..99ae293cc 100644
+index dbcb5ef5d..ab763741d 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
 @@ -205,6 +205,8 @@ add_library(ggml-base
@@ -58,10 +51,10 @@ index d55aed348..99ae293cc 100644
  
  set_target_properties(ggml-base PROPERTIES
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 6852d2e20..334a30135 100644
+index 99fecd81e..2a76f0485 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
+@@ -262,6 +262,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
      for (int id = 0; id < info.device_count; ++id) {
          int device_vmm = 0;
  
@@ -78,19 +71,7 @@ index 6852d2e20..334a30135 100644
  #if defined(GGML_USE_VMM)
          CUdevice device;
          CU_CHECK(cuDeviceGet(&device, id));
-@@ -320,6 +330,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
- #else
-         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-         info.devices[id].cc = 100*prop.major + 10*prop.minor;
-+#ifdef __CUDA_ARCH_LIST__
-+        if (std::getenv("GGML_CUDA_INIT") != NULL) {
-+            GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch");
-+        }
-+#endif // defined(__CUDA_ARCH_LIST__)
-         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n",
-                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
-                         ggml_cuda_parse_uuid(prop, id).c_str());
-@@ -4317,6 +4332,11 @@ struct ggml_backend_cuda_device_context {
+@@ -4522,6 +4532,11 @@ struct ggml_backend_cuda_device_context {
      std::string description;
      std::string pci_bus_id;
      std::string id;
@@ -99,10 +80,10 @@ index 6852d2e20..334a30135 100644
 +    int driver_major;
 +    int driver_minor;
 +    int integrated;
+     int op_offload_min_batch_size;
  };
  
- static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
-@@ -4413,6 +4433,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
+@@ -4619,6 +4634,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
  static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
      ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
      ggml_cuda_set_device(ctx->device);
@@ -131,7 +112,7 @@ index 6852d2e20..334a30135 100644
      CUDA_CHECK(cudaMemGetInfo(free, total));
  
  // ref: https://github.com/ggml-org/llama.cpp/pull/17368
-@@ -4445,6 +4487,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
+@@ -4651,6 +4688,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
      return GGML_BACKEND_DEVICE_TYPE_GPU;
  }
  
@@ -139,10 +120,15 @@ index 6852d2e20..334a30135 100644
  static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
      ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
  
-@@ -4458,6 +4501,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
-     // If you need the memory data, call ggml_backend_dev_memory() explicitly.
-     props->memory_total = props->memory_free = 0;
- 
+@@ -4659,7 +4697,22 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
+     props->id          = ggml_backend_cuda_device_get_id(dev);
+     props->type        = ggml_backend_cuda_device_get_type(dev);
+     props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
+-    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
++    // Prefer calling ggml_backend_dev_memory() explicitly if you need memory data.
++    // If you need the memory data, call ggml_backend_dev_memory() explicitly.
++    props->memory_total = props->memory_free = 0;
++
 +#if defined(GGML_USE_HIP)
 +    int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD;
 +    props->compute_major = cc / 0x100;
@@ -155,22 +141,22 @@ index 6852d2e20..334a30135 100644
 +    props->driver_minor = ctx->driver_minor;
 +    props->integrated = ctx->integrated;
 +    props->library = GGML_CUDA_NAME;
-+
+ 
      bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
  #ifdef GGML_CUDA_NO_PEER_COPY
-     bool events = false;
-@@ -5047,6 +5103,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
-         std::lock_guard lock(mutex);
+@@ -5266,6 +5319,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
          if (!initialized) {
              ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
+             const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
 +            int driverVersion = 0;
  
              for (int i = 0; i < ggml_cuda_info().device_count; i++) {
                  ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
-@@ -5062,6 +5119,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
+@@ -5280,6 +5334,15 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
+                 char pci_bus_id[16] = {};
                  snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
                  dev_ctx->pci_bus_id = pci_bus_id;
- 
++
 +                dev_ctx->major = prop.major;
 +                dev_ctx->minor = prop.minor;
 +                if (driverVersion == 0) {
@@ -179,11 +165,11 @@ index 6852d2e20..334a30135 100644
 +                dev_ctx->driver_major = driverVersion / 1000;
 +                dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
 +                dev_ctx->integrated = prop.integrated;
+                 dev_ctx->op_offload_min_batch_size = min_batch_size;
+ 
                  ggml_backend_dev_t dev = new ggml_backend_device {
-                     /* .iface   = */ ggml_backend_cuda_device_interface,
-                     /* .reg     = */ ®,
 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
-index 4e162258d..d89e35a8e 100644
+index 5c172bf5d..14473a97c 100644
 --- a/ggml/src/ggml-cuda/vendors/hip.h
 +++ b/ggml/src/ggml-cuda/vendors/hip.h
 @@ -5,6 +5,8 @@
@@ -195,7 +181,7 @@ index 4e162258d..d89e35a8e 100644
  
  #if defined(GGML_HIP_ROCWMMA_FATTN)
  #include 
-@@ -51,6 +53,7 @@
+@@ -53,6 +55,7 @@
  #define cudaDeviceProp hipDeviceProp_t
  #define cudaDeviceReset hipDeviceReset
  #define cudaDeviceSynchronize hipDeviceSynchronize
@@ -204,10 +190,10 @@ index 4e162258d..d89e35a8e 100644
  #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
  #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
 diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
-index fe57d4c58..dba8f4695 100644
+index e3714b38a..80ebb059e 100644
 --- a/ggml/src/ggml-impl.h
 +++ b/ggml/src/ggml-impl.h
-@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
+@@ -680,6 +680,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
      return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
  }
  
@@ -223,10 +209,10 @@ index fe57d4c58..dba8f4695 100644
  }
  #endif
 diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
-index ba95b4acc..f6f8f7a10 100644
+index ba12c7c14..128c10fa0 100644
 --- a/ggml/src/ggml-metal/ggml-metal.cpp
 +++ b/ggml/src/ggml-metal/ggml-metal.cpp
-@@ -546,6 +546,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
+@@ -664,6 +664,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
      GGML_UNUSED(dev);
  }
  
@@ -234,19 +220,19 @@ index ba95b4acc..f6f8f7a10 100644
  static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
      props->name        = ggml_backend_metal_device_get_name(dev);
      props->description = ggml_backend_metal_device_get_description(dev);
-@@ -554,6 +555,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
+@@ -672,6 +673,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
  
      ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
  
 +    props->library = GGML_METAL_NAME;
      props->caps = {
-         /* .async                 = */ true,
-         /* .host_buffer           = */ false,
+         /* .async                = */ true,
+         /* .host_buffer          = */ false,
 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index 5349bce24..0103fd03a 100644
+index 008b82e9b..1d3705400 100644
 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-@@ -236,6 +236,7 @@ class vk_memory_logger;
+@@ -243,6 +243,7 @@ class vk_memory_logger;
  class vk_perf_logger;
  static void ggml_vk_destroy_buffer(vk_buffer& buf);
  static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
@@ -254,7 +240,7 @@ index 5349bce24..0103fd03a 100644
  
  static constexpr uint32_t mul_mat_vec_max_cols = 8;
  static constexpr uint32_t p021_max_gqa_ratio = 8;
-@@ -12350,6 +12351,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
+@@ -13212,6 +13213,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
      snprintf(description, description_size, "%s", props.deviceName.data());
  }
  
@@ -284,25 +270,11 @@ index 5349bce24..0103fd03a 100644
  // backend interface
  
  #define UNUSED GGML_UNUSED
-@@ -13628,15 +13652,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
-     ggml_vk_get_device_description(dev_idx, description, description_size);
- }
- 
--void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-+std::string ggml_backend_vk_get_device_id(int device) {
-     GGML_ASSERT(device < (int) vk_instance.device_indices.size());
--    GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
-+    int dev_idx = vk_instance.device_indices[device];
-+    return ggml_vk_get_device_id(dev_idx);
-+}
-+
-+//////////////////////////
-+
-+struct ggml_backend_vk_device_context {
-+    size_t device;
-+    std::string name;
-+    std::string description;
-+    bool is_integrated_gpu;
+@@ -14863,7 +14887,14 @@ struct ggml_backend_vk_device_context {
+     std::string name;
+     std::string description;
+     bool is_integrated_gpu;
+-    std::string pci_bus_id;
 +    // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
 +    std::string pci_id;
 +    std::string id;
@@ -311,32 +283,34 @@ index 5349bce24..0103fd03a 100644
 +    int minor;
 +    int driver_major;
 +    int driver_minor;
-+};
+     int op_offload_min_batch_size;
+ };
  
--    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
-+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
-+    GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
-+    GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
-+
-+    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
-     vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
-     vk::PhysicalDeviceMemoryProperties2 memprops = {};
--    const bool membudget_supported = vk_instance.device_supports_membudget[device];
-+    const bool membudget_supported = vk_instance.device_supports_membudget[ctx->device];
-     const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-+    
-+    vk::PhysicalDeviceProperties2 props2;
-+    vkdev.getProperties2(&props2);
-+
-+    if (!is_integrated_gpu)
-+    {
-+        // Use vendor specific management libraries for best VRAM reporting if available
+@@ -14877,8 +14908,48 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
+     return ctx->description.c_str();
+ }
+ 
++static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
++    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
++    return ctx->id.c_str();
++}
++
+ static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
++
++    // Use vendor specific management libraries for best VRAM reporting if available
++    if (!ctx->is_integrated_gpu) {
++        GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
++        vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
++        vk::PhysicalDeviceProperties2 props2;
++        vkdev.getProperties2(&props2);
++
 +        switch (props2.properties.vendorID) {
 +        case VK_VENDOR_ID_AMD:
 +            if (ggml_hip_mgmt_init() == 0) {
-+                int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
++                int status = ggml_hip_get_device_memory(!ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
 +                if (status == 0) {
-+                    GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
++                    GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, !ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
 +                    ggml_hip_mgmt_release();
 +                    return;
 +                }
@@ -356,77 +330,12 @@ index 5349bce24..0103fd03a 100644
 +            break;
 +        }
 +    }
-+    // else fallback to memory budget if supported
 +
- 
-     if (membudget_supported) {
-         memprops.pNext = &budgetprops;
-@@ -13688,8 +13769,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
-         }
-     }
- 
-+    vk::PhysicalDeviceProperties2 props2;
-     if (!ext_support) {
--        return "";
-+        device.getProperties2(&props2);
-+        if (props2.properties.vendorID != VK_VENDOR_ID_AMD) {
-+            return "";
-+        }
-+        // AMD doesn't claim to support PCI ID, but actually does, so try anyway and check for non-zero
-     }
- 
-     vk::PhysicalDeviceProperties2 props = {};
-@@ -13706,19 +13792,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
- 
-     char pci_bus_id[16] = {};
-     snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
-+    if (pci_domain == 0 && pci_bus == 0 && pci_device == 0 && pci_function == 0) {
-+        return "";
-+    }
- 
-     return std::string(pci_bus_id);
++    // Fallback to Vulkan memory budget
+     ggml_backend_vk_get_device_memory(ctx->device, free, total);
  }
  
--//////////////////////////
--
--struct ggml_backend_vk_device_context {
--    size_t device;
--    std::string name;
--    std::string description;
--    bool is_integrated_gpu;
--    std::string pci_bus_id;
--};
-+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) {
-+    if (id.empty()) return false;
-+    unsigned int d = 0, b = 0, dev = 0, func = 0;
-+    // Expected format: dddd:bb:dd.f (all hex)
-+    int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func);
-+    if (n < 4) return false;
-+    if (domain) *domain = (int) d;
-+    if (bus) *bus = (int) b;
-+    if (device) *device = (int) dev;
-+    return true;
-+}
- 
- static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
-     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-@@ -13730,9 +13821,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
-     return ctx->description.c_str();
- }
- 
-+static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
-+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-+    return ctx->id.c_str();
-+}
-+
- static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
-     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
--    ggml_backend_vk_get_device_memory(ctx->device, free, total);
-+    ggml_backend_vk_get_device_memory(ctx, free, total);
- }
- 
- static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
-@@ -13756,8 +13852,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
+@@ -14903,15 +14974,23 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
  
      props->name        = ggml_backend_vk_device_get_name(dev);
      props->description = ggml_backend_vk_device_get_description(dev);
@@ -436,10 +345,12 @@ index 5349bce24..0103fd03a 100644
 +    props->device_id   = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str();
      ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
      props->caps = {
-         /* .async                 = */ false,
-@@ -13765,6 +13862,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
+-        /* .async                 = */ true,
++        /* .async                 = */ false,
+         /* .host_buffer           = */ true,
          /* .buffer_from_host_ptr  = */ false,
-         /* .events                = */ false,
+-        /* .events                = */ true,
++        /* .events                = */ false,
      };
 +
 +    props->compute_major = ctx->major;
@@ -451,22 +362,24 @@ index 5349bce24..0103fd03a 100644
  }
  
  static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
-@@ -14331,6 +14435,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
+@@ -15603,7 +15682,9 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
          static std::mutex mutex;
          std::lock_guard lock(mutex);
          if (!initialized) {
 +            std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices();
+             const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
 +
              for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
                  ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
                  char desc[256];
-@@ -14339,12 +14445,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
+@@ -15612,13 +15693,42 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
                  ctx->name = GGML_VK_NAME + std::to_string(i);
                  ctx->description = desc;
                  ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
 -                ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+-                ctx->op_offload_min_batch_size = min_batch_size;
 +                ctx->pci_id = ggml_backend_vk_get_device_pci_id(i);
-+                ctx->id = ggml_backend_vk_get_device_id(i);
++                ctx->id = ggml_vk_get_device_id(i);
                  devices.push_back(new ggml_backend_device {
                      /* .iface   = */ ggml_backend_vk_device_i,
                      /* .reg     = */ reg,
@@ -488,8 +401,8 @@ index 5349bce24..0103fd03a 100644
 +                std::ostringstream oss;
 +                oss << std::hex << std::setfill('0');
 +                int byteIdx = 0;
-+                for (int i = 0; i < 16; ++i, ++byteIdx) {
-+                    oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]);
++                for (int j = 0; j < 16; ++j, ++byteIdx) {
++                    oss << std::setw(2) << static_cast(device_id_props.deviceUUID[j]);
 +                    if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) {
 +                        oss << '-';
 +                    }
@@ -500,12 +413,13 @@ index 5349bce24..0103fd03a 100644
 +                // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
 +                ctx->driver_major = 0;
 +                ctx->driver_minor = 0;
++                ctx->op_offload_min_batch_size = min_batch_size;
              }
              initialized = true;
          }
 diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
 new file mode 100644
-index 000000000..23c765806
+index 000000000..734d437a7
 --- /dev/null
 +++ b/ggml/src/mem_hip.cpp
 @@ -0,0 +1,558 @@
@@ -799,7 +713,7 @@ index 000000000..23c765806
 +        const char *version = NULL;
 +        ADLX_RESULT status = adlx.ADLXQueryVersion(&version);
 +        if (ADLX_SUCCEEDED(status)) {
-+            GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version);  
++            GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version);
 +        }
 +    }
 +
@@ -917,7 +831,7 @@ index 000000000..23c765806
 +            adlx_gdm_cleanup;
 +            return status;
 +        }
-+        
++
 +        adlx_uint totalVRAM = 0;
 +        status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM);
 +        if (ADLX_FAILED(status)) {
@@ -1067,10 +981,9 @@ index 000000000..23c765806
 +} // extern "C"
 +
 +#endif // #ifdef _WIN32
-\ No newline at end of file
 diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
 new file mode 100644
-index 000000000..c9073cef0
+index 000000000..b5d46cbe7
 --- /dev/null
 +++ b/ggml/src/mem_nvml.cpp
 @@ -0,0 +1,209 @@
@@ -1283,4 +1196,3 @@ index 000000000..c9073cef0
 +}
 +
 +}
-\ No newline at end of file
diff --git a/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch b/llama/patches/0024-NVML-fallback-for-unified-memory-GPUs.patch
similarity index 99%
rename from llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch
rename to llama/patches/0024-NVML-fallback-for-unified-memory-GPUs.patch
index ec3fdbaabb2..e74ec145d9d 100644
--- a/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch
+++ b/llama/patches/0024-NVML-fallback-for-unified-memory-GPUs.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] NVML fallback for unified memory GPUs
  1 file changed, 68 insertions(+), 3 deletions(-)
 
 diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
-index c9073cef0..f473a2a2c 100644
+index b5d46cbe7..f8a4ac7b5 100644
 --- a/ggml/src/mem_nvml.cpp
 +++ b/ggml/src/mem_nvml.cpp
 @@ -13,6 +13,7 @@
diff --git a/llama/patches/0025-report-LoadLibrary-failures.patch b/llama/patches/0025-report-LoadLibrary-failures.patch
new file mode 100644
index 00000000000..ba799737eb1
--- /dev/null
+++ b/llama/patches/0025-report-LoadLibrary-failures.patch
@@ -0,0 +1,55 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Daniel Hiltgen 
+Date: Fri, 17 Oct 2025 14:17:00 -0700
+Subject: [PATCH] report LoadLibrary failures
+
+---
+ ggml/src/ggml-backend-dl.cpp | 28 ++++++++++++++++++++++++++++
+ 1 file changed, 28 insertions(+)
+
+diff --git a/ggml/src/ggml-backend-dl.cpp b/ggml/src/ggml-backend-dl.cpp
+index a65cf0090..34949c261 100644
+--- a/ggml/src/ggml-backend-dl.cpp
++++ b/ggml/src/ggml-backend-dl.cpp
+@@ -1,13 +1,41 @@
+ #include "ggml-backend-dl.h"
++#include "ggml-impl.h"
+ 
+ #ifdef _WIN32
+ 
++static std::string path_str(const fs::path & path) {
++    try {
++#if defined(__cpp_lib_char8_t)
++        // C++20 and later: u8string() returns std::u8string
++        const std::u8string u8str = path.u8string();
++        return std::string(reinterpret_cast(u8str.data()), u8str.size());
++#else
++        // C++17: u8string() returns std::string
++        return path.u8string();
++#endif
++    } catch (...) {
++        return std::string();
++    }
++}
++
+ dl_handle * dl_load_library(const fs::path & path) {
+     // suppress error dialogs for missing DLLs
+     DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+     SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+ 
+     HMODULE handle = LoadLibraryW(path.wstring().c_str());
++    if (!handle) {
++        DWORD error_code = GetLastError();
++        std::string msg;
++        LPSTR lpMsgBuf = NULL;
++        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
++                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
++        if (bufLen) {
++            msg = lpMsgBuf;
++            LocalFree(lpMsgBuf);
++            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
++        }
++    }
+ 
+     SetErrorMode(old_mode);
+ 
diff --git a/llama/patches/0027-interleave-multi-rope.patch b/llama/patches/0026-interleave-multi-rope.patch
similarity index 74%
rename from llama/patches/0027-interleave-multi-rope.patch
rename to llama/patches/0026-interleave-multi-rope.patch
index 6ca94029d7f..b9601693fba 100644
--- a/llama/patches/0027-interleave-multi-rope.patch
+++ b/llama/patches/0026-interleave-multi-rope.patch
@@ -13,10 +13,10 @@ interleaved version used for qwen3vl
  4 files changed, 16 insertions(+), 16 deletions(-)
 
 diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
-index 7d1733adb..f4aae5332 100644
+index c90b7db7e..5988ef40b 100644
 --- a/ggml/src/ggml-cpu/ops.cpp
 +++ b/ggml/src/ggml-cpu/ops.cpp
-@@ -5599,14 +5599,14 @@ static void ggml_mrope_cache_init(
+@@ -5663,14 +5663,14 @@ static void ggml_mrope_cache_init(
  
          float theta = theta_t;
          if (is_imrope) { // qwen3vl apply interleaved mrope
@@ -36,33 +36,33 @@ index 7d1733adb..f4aae5332 100644
          } else {
              if (sector >= sections[0] && sector < sec_w) {
 diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
-index 88ed79111..71ca60214 100644
+index 45a49a5dc..f47392de6 100644
 --- a/ggml/src/ggml-cuda/rope.cu
 +++ b/ggml/src/ggml-cuda/rope.cu
-@@ -200,14 +200,14 @@ static __global__ void rope_multi(
+@@ -229,14 +229,14 @@ static __global__ void rope_multi(const T *            x,
  
      float theta_base = 0.0;
      if (is_imrope) {
--        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
-+        if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { // h
-             theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
--        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
-+        } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { // w
-             theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-         } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
-             theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+-        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {         // h
++        if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) {         // h
+             theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+-        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {  // w
++        } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) {  // w
+             theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+         } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {  // t
+             theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
 -        } else {
--            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+-            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
 +        // } else {
-+        //     theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
++        //    theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
          }
      } else {
          if (sector < sections.v[0]) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index 236838e9e..c98d269d1 100644
+index 628af4bb4..1fbdfc385 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -4242,14 +4242,14 @@ kernel void kernel_rope_multi(
+@@ -4024,14 +4024,14 @@ kernel void kernel_rope_multi(
  
              float theta_base;
              if (FC_rope_is_imrope) {
@@ -82,25 +82,25 @@ index 236838e9e..c98d269d1 100644
              } else {
                  if (sector < args.sect_0) {
 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
-index 9726b722d..1c8c69422 100644
+index 2e5345990..f2028c4c5 100644
 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
 +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
-@@ -148,14 +148,14 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
+@@ -135,14 +135,14 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope
  
      float theta_base = 0.0;
      if (p.is_imrope != 0) {
 -        if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
 +        if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) {
-             theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+             theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
 -        } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
 +        } else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) {
-             theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+             theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
          } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
-             theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+             theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
 -        } else {
--            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+-            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
 +        //} else {
-+        //    theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
++        //    theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
          }
      } else {
          if (sector < p.sections[0]) {
diff --git a/llama/patches/0026-report-LoadLibrary-failures.patch b/llama/patches/0026-report-LoadLibrary-failures.patch
deleted file mode 100644
index 7f0e9be92e8..00000000000
--- a/llama/patches/0026-report-LoadLibrary-failures.patch
+++ /dev/null
@@ -1,32 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Daniel Hiltgen 
-Date: Fri, 17 Oct 2025 14:17:00 -0700
-Subject: [PATCH] report LoadLibrary failures
-
----
- ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++
- 1 file changed, 12 insertions(+)
-
-diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 079dba211..2474e0ed6 100644
---- a/ggml/src/ggml-backend-reg.cpp
-+++ b/ggml/src/ggml-backend-reg.cpp
-@@ -126,6 +126,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
-     SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
- 
-     HMODULE handle = LoadLibraryW(path.wstring().c_str());
-+    if (!handle) {
-+        DWORD error_code = GetLastError();
-+        std::string msg;
-+        LPSTR lpMsgBuf = NULL;
-+        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-+                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-+        if (bufLen) {
-+            msg = lpMsgBuf;
-+            LocalFree(lpMsgBuf);
-+            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
-+        }
-+    }
- 
-     SetErrorMode(old_mode);
- 
diff --git a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch b/llama/patches/0027-ollama-Add-memory-detection-using-DXGI-PDH.patch
similarity index 91%
rename from llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch
rename to llama/patches/0027-ollama-Add-memory-detection-using-DXGI-PDH.patch
index e7bca2de0ae..f88f84527f0 100644
--- a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch
+++ b/llama/patches/0027-ollama-Add-memory-detection-using-DXGI-PDH.patch
@@ -1,18 +1,21 @@
 From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
 From: Viraj Wadhwa 
 Date: Tue, 4 Nov 2025 12:04:04 -0800
-Subject: [PATCH] Add memory detection using DXGI + PDH
+Subject: [PATCH] ollama: Add memory detection using DXGI + PDH
 
+Add Windows-specific VRAM detection using DXGI and PDH performance counters.
+This provides accurate memory reporting for both integrated and discrete GPUs.
+Add luid field to Vulkan device context for device matching.
 ---
  ggml/src/CMakeLists.txt              |   1 +
  ggml/src/ggml-impl.h                 |   3 +
- ggml/src/ggml-vulkan/ggml-vulkan.cpp |  26 ++-
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp |  24 +++
  ggml/src/mem_dxgi_pdh.cpp            | 297 +++++++++++++++++++++++++++
- 4 files changed, 325 insertions(+), 2 deletions(-)
+ 4 files changed, 325 insertions(+)
  create mode 100644 ggml/src/mem_dxgi_pdh.cpp
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 99ae293cc..9a134b7af 100644
+index ab763741d..201a260a3 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
 @@ -207,6 +207,7 @@ add_library(ggml-base
@@ -24,10 +27,10 @@ index 99ae293cc..9a134b7af 100644
  
  set_target_properties(ggml-base PROPERTIES
 diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
-index dba8f4695..7e17032c7 100644
+index 80ebb059e..b502fdf78 100644
 --- a/ggml/src/ggml-impl.h
 +++ b/ggml/src/ggml-impl.h
-@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
+@@ -687,6 +687,9 @@ GGML_API void ggml_nvml_release();
  GGML_API int ggml_hip_mgmt_init();
  GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
  GGML_API void ggml_hip_mgmt_release();
@@ -38,7 +41,7 @@ index dba8f4695..7e17032c7 100644
  #ifdef __cplusplus
  }
 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
-index 0103fd03a..9cc4ebdef 100644
+index 1d3705400..5bd306307 100644
 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
 @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
@@ -49,7 +52,7 @@ index 0103fd03a..9cc4ebdef 100644
  
  typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
      VkStructureType                       sType;
-@@ -13669,6 +13670,7 @@ struct ggml_backend_vk_device_context {
+@@ -14891,6 +14892,7 @@ struct ggml_backend_vk_device_context {
      std::string pci_id;
      std::string id;
      std::string uuid;
@@ -57,12 +60,12 @@ index 0103fd03a..9cc4ebdef 100644
      int major;
      int minor;
      int driver_major;
-@@ -13687,6 +13689,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
-     
-     vk::PhysicalDeviceProperties2 props2;
-     vkdev.getProperties2(&props2);
-+    GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str());
-+    GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: luid %s\n", ctx->luid.c_str());
+@@ -14915,6 +14917,20 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
+ 
+ static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
++    GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: uuid %s\n", ctx->uuid.c_str());
++    GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: luid %s\n", ctx->luid.c_str());
 +
 +    // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic)
 +    if (ggml_dxgi_pdh_init() == 0) {
@@ -76,25 +79,9 @@ index 0103fd03a..9cc4ebdef 100644
 +        ggml_dxgi_pdh_release();
 +    }
  
-     if (!is_integrated_gpu)
-     {
-@@ -13718,7 +13734,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
-     }
-     // else fallback to memory budget if supported
- 
--
-     if (membudget_supported) {
-         memprops.pNext = &budgetprops;
-     }
-@@ -14452,7 +14467,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
-                     /* .reg     = */ reg,
-                     /* .context = */ ctx,
-                 });
--
-                 // Gather additional information about the device
-                 int dev_idx = vk_instance.device_indices[i];
-                 vk::PhysicalDeviceProperties props1;
-@@ -14475,6 +14489,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
+     // Use vendor specific management libraries for best VRAM reporting if available
+     if (!ctx->is_integrated_gpu) {
+@@ -15723,6 +15739,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
                      }
                  }
                  ctx->uuid = oss.str();
@@ -111,7 +98,7 @@ index 0103fd03a..9cc4ebdef 100644
                  // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
 diff --git a/ggml/src/mem_dxgi_pdh.cpp b/ggml/src/mem_dxgi_pdh.cpp
 new file mode 100644
-index 000000000..2f395761c
+index 000000000..4dd66c25f
 --- /dev/null
 +++ b/ggml/src/mem_dxgi_pdh.cpp
 @@ -0,0 +1,297 @@
@@ -158,7 +145,7 @@ index 000000000..2f395761c
 +    void *pdh_dll_handle;
 +    // DXGI Functions
 +    HRESULT (*CreateDXGIFactory1)(REFIID riid, void **ppFactory);
-+    // PDH functions  
++    // PDH functions
 +    PDH_STATUS (*PdhOpenQueryW)(LPCWSTR szDataSource, DWORD_PTR dwUserData, PDH_HQUERY *phQuery);
 +    PDH_STATUS (*PdhAddCounterW)(PDH_HQUERY hQuery, LPCWSTR szFullCounterPath, DWORD_PTR dwUserData, PDH_HCOUNTER *phCounter);
 +    PDH_STATUS (*PdhCollectQueryData)(PDH_HQUERY hQuery);
@@ -213,7 +200,7 @@ index 000000000..2f395761c
 +        while (pFactory->EnumAdapters1(i, &pAdapter) != DXGI_ERROR_NOT_FOUND) {
 +            DXGI_ADAPTER_DESC1 desc;
 +            pAdapter->GetDesc1(&desc);
-+            
++
 +            // Get all the GPU adapter info
 +            GpuInfo info;
 +            fetch_dxgi_adapter_desc1(desc, &info);
@@ -314,7 +301,7 @@ index 000000000..2f395761c
 +        dll_functions.PdhCollectQueryData = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCollectQueryData");
 +        dll_functions.PdhGetFormattedCounterValue = (PDH_STATUS (*)(PDH_HCOUNTER hCounter, DWORD dwFormat, LPDWORD lpdwType, PPDH_FMT_COUNTERVALUE pValue)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhGetFormattedCounterValue");
 +        dll_functions.PdhCloseQuery = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCloseQuery");
-+    
++
 +        SetErrorMode(old_mode); // set old mode before any return
 +
 +        // Check if any function pointers are NULL (not found)
@@ -326,7 +313,7 @@ index 000000000..2f395761c
 +            dll_functions.pdh_dll_handle = NULL;
 +            return ERROR_PROC_NOT_FOUND;
 +        }
-+    
++
 +        // No other initializations needed, successfully loaded the libraries and functions!
 +        return ERROR_SUCCESS;
 +    }
@@ -412,4 +399,3 @@ index 000000000..2f395761c
 +} // extern "C"
 +
 +#endif // #ifdef _WIN32
-\ No newline at end of file
diff --git a/llama/patches/0029-ggml-cuda-skip-large-batches.patch b/llama/patches/0028-ggml-cuda-skip-large-batches.patch
similarity index 91%
rename from llama/patches/0029-ggml-cuda-skip-large-batches.patch
rename to llama/patches/0028-ggml-cuda-skip-large-batches.patch
index 483c56537de..968f23e8cbb 100644
--- a/llama/patches/0029-ggml-cuda-skip-large-batches.patch
+++ b/llama/patches/0028-ggml-cuda-skip-large-batches.patch
@@ -10,10 +10,10 @@ fallback to cpu
  1 file changed, 3 insertions(+)
 
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 334a30135..5c9dfd032 100644
+index 2a76f0485..e42a1599b 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
+@@ -4835,6 +4835,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                  if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                      return false;
                  }
diff --git a/llama/patches/0030-fix-bakllava-regression.patch b/llama/patches/0029-fix-bakllava-regression.patch
similarity index 91%
rename from llama/patches/0030-fix-bakllava-regression.patch
rename to llama/patches/0029-fix-bakllava-regression.patch
index 14ef26b57ff..97b57d76d82 100644
--- a/llama/patches/0030-fix-bakllava-regression.patch
+++ b/llama/patches/0029-fix-bakllava-regression.patch
@@ -9,10 +9,10 @@ Rever to prior logic of assuming an empty projector type is mlp
  1 file changed, 4 insertions(+)
 
 diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
-index 84a3796b5..d3a37842d 100644
+index d12110a5c..db83c9ecf 100644
 --- a/tools/mtmd/clip.cpp
 +++ b/tools/mtmd/clip.cpp
-@@ -960,6 +960,10 @@ struct clip_model_loader {
+@@ -992,6 +992,10 @@ struct clip_model_loader {
              if (proj_type.empty()) {
                  if (modality == CLIP_MODALITY_VISION) {
                      get_string(KEY_VISION_PROJ_TYPE, proj_type, false);
diff --git a/llama/patches/0031-win-exit-instead-of-abort.patch b/llama/patches/0030-win-exit-instead-of-abort.patch
similarity index 87%
rename from llama/patches/0031-win-exit-instead-of-abort.patch
rename to llama/patches/0030-win-exit-instead-of-abort.patch
index 4e4edcbd163..716c9c24121 100644
--- a/llama/patches/0031-win-exit-instead-of-abort.patch
+++ b/llama/patches/0030-win-exit-instead-of-abort.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] win: exit instead of abort
  1 file changed, 6 insertions(+), 1 deletion(-)
 
 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
-index eb3ae72ea..c9242a15a 100644
+index e9529fbb6..d3c039bd9 100644
 --- a/ggml/src/ggml.c
 +++ b/ggml/src/ggml.c
-@@ -250,8 +250,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
+@@ -252,8 +252,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
          fprintf(stderr, "%s\n", message);
          ggml_print_backtrace();
      }
diff --git a/llama/patches/0031-Improve-ggml_backend_vk_get_device_pci_id.patch b/llama/patches/0031-Improve-ggml_backend_vk_get_device_pci_id.patch
new file mode 100644
index 00000000000..fc38424eb46
--- /dev/null
+++ b/llama/patches/0031-Improve-ggml_backend_vk_get_device_pci_id.patch
@@ -0,0 +1,63 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: inforithmics 
+Date: Sat, 7 Feb 2026 21:35:07 +0100
+Subject: [PATCH] Improve ggml_backend_vk_get_device_pci_id
+
+---
+ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 31 ++++++++++++++++++----------
+ 1 file changed, 20 insertions(+), 11 deletions(-)
+
+diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+index 5bd306307..3f7abaf45 100644
+--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+@@ -14850,10 +14850,10 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
+ 
+     const std::vector ext_props = device.enumerateDeviceExtensionProperties();
+ 
+-    bool ext_support = false;
+-
++    vk::PhysicalDeviceProperties devProps = device.getProperties();
++    bool ext_support = devProps.vendorID == VK_VENDOR_ID_AMD || devProps.vendorID == VK_VENDOR_ID_NVIDIA;
+     for (const auto& properties : ext_props) {
+-        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
++        if (strcmp(properties.extensionName, VK_EXT_PCI_BUS_INFO_EXTENSION_NAME) == 0) {
+             ext_support = true;
+             break;
+         }
+@@ -14867,18 +14867,27 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
+     vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
+ 
+     props.pNext = &pci_bus_info;
++    try {
++        device.getProperties2(&props);
+ 
+-    device.getProperties2(&props);
++        // If not supported and values are 0, it might be invalid
++        if (!ext_support && pci_bus_info.pciDomain == 0 && pci_bus_info.pciBus == 0 &&
++            pci_bus_info.pciDevice == 0 && pci_bus_info.pciFunction == 0) {
++            return "";
++        }
+ 
+-    const uint32_t pci_domain = pci_bus_info.pciDomain;
+-    const uint32_t pci_bus = pci_bus_info.pciBus;
+-    const uint32_t pci_device = pci_bus_info.pciDevice;
+-    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
++        const uint32_t pci_domain = pci_bus_info.pciDomain;
++        const uint32_t pci_bus = pci_bus_info.pciBus;
++        const uint32_t pci_device = pci_bus_info.pciDevice;
++        const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
+ 
+-    char pci_bus_id[16] = {};
+-    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
++        char pci_bus_id[16] = {};
++        snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
+ 
+-    return std::string(pci_bus_id);
++        return std::string(pci_bus_id);
++    } catch(...) {
++        return "";
++    }
+ }
+ 
+ //////////////////////////
diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch
deleted file mode 100644
index abd7df93014..00000000000
--- a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch
+++ /dev/null
@@ -1,309 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: nobody <>
-Date: Sat, 24 Jan 2026 02:31:01 +0000
-Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
-
-Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
-uses head size 576 with gqa_ratio 4, which was previously only supported
-for gqa_ratio 16 (DeepSeek).
-
-Metal changes:
-- Enable head size 576 for flash attention
-- Increase simdgroups to 8 for large heads (>=512)
-- Add case 8 kernel dispatch for 8 simdgroups
-
-CUDA changes:
-- Add gqa_ratio 4 support for head 576/512
-- Add tile configs for (576, 512, 4) and (576, 512, 8)
-- Add MMA config cases for ncols 4
-- Add template instances for ncols2=4
-- Fix nbatch_fa values in nvidia_fp32 config (32->64)
----
- ggml/src/ggml-cuda/fattn-mma-f16.cuh          | 40 +++++++++++++++----
- ggml/src/ggml-cuda/fattn-tile.cuh             | 16 ++++++++
- ggml/src/ggml-cuda/fattn.cu                   | 12 ++++--
- ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu |  1 +
- ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu |  1 +
- ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu |  1 +
- ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu |  1 +
- ggml/src/ggml-metal/ggml-metal-device.m       |  8 +---
- ggml/src/ggml-metal/ggml-metal-ops.cpp        |  2 +-
- ggml/src/ggml-metal/ggml-metal.metal          |  1 +
- 10 files changed, 64 insertions(+), 19 deletions(-)
-
-diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
-index 7bd1044c1..3dea2205e 100644
---- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
-+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
-@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);
- 
--    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32, 288, 256, 128, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, true);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
-@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
- 
--    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32,  96,  64, 128, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, true);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
-@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
- }
- 
- static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
--    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32, 288, 256,  64, 1, false);
-+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, true);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
-     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);
-@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
-     constexpr int  ncols           = ncols1 * ncols2;
-     constexpr int  cols_per_warp   = T_B_KQ::I;
-     constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
--    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
-+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
-     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
-     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
-     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
-@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
-                 }
-             }
-         } else {
--            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
- #pragma unroll
-             for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
-                 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
-@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
-                     T_A_KQ K_A;
-                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- 
--                    // Wide version of KQ_C is column-major => swap A and B.
--                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
-+                    if constexpr (cols_per_warp == 8) {
-+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
-+                    } else {
-+                        // Wide version of KQ_C is column-major
-+#if defined(AMD_WMMA_AVAILABLE)
-+                        // RDNA matrix C is column-major.
-+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
-+#else
-+                        // swap A and B for CUDA.
-+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
-+#endif // defined(AMD_WMMA_AVAILABLE)
-+                    }
-                 }
-             }
-         }
-@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
- 
-     constexpr int  cols_per_warp   = T_B_KQ::I;
-     constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
--    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
-+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
-     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
-     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
-     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
-@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
-         NO_DEVICE_CODE;
-         return;
-     }
-+#ifdef VOLTA_MMA_AVAILABLE
-+    if (ncols1*ncols2 < 32) {
-+        NO_DEVICE_CODE;
-+        return;
-+    }
-+#endif // VOLTA_MMA_AVAILABLE
-+
- #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
-     if (ncols1*ncols2 > 32) {
-         NO_DEVICE_CODE;
-@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
- extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
- extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
- extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
-+
-+// For GLM 4.7 Flash
-+extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
-+extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
-+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
-diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
-index 7c4d6fe67..371be7442 100644
---- a/ggml/src/ggml-cuda/fattn-tile.cuh
-+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
-@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
- 
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
- 
-     return 0;
-@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
- 
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
- 
-     return 0;
-@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
- 
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
- 
-@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
- 
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
-+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
-     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
- 
-@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
-             launch_fattn_tile_switch_ncols1(ctx, dst);
-             return;
-         }
-+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
-+            launch_fattn_tile_switch_ncols1(ctx, dst);
-+            return;
-+        }
-+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
-+            launch_fattn_tile_switch_ncols1(ctx, dst);
-+            return;
-+        }
-     }
- 
-     if constexpr (DV <= 256) {
-diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
-index 015540666..1693479cb 100644
---- a/ggml/src/ggml-cuda/fattn.cu
-+++ b/ggml/src/ggml-cuda/fattn.cu
-@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
-             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
-             break;
-         case 576: {
--            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
-+            // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
-             GGML_ASSERT(V->ne[0] == 512);
-             float max_bias = 0.0f;
-             memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
-@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
- 
-             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
-             const int gqa_ratio = Q->ne[2] / K->ne[2];
--            GGML_ASSERT(gqa_ratio % 16 == 0);
--            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
-+            GGML_ASSERT(gqa_ratio % 4 == 0);
-+            if (gqa_ratio % 16 == 0) {
-+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
-+            } else {
-+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
-+            }
-         } break;
-         default:
-             GGML_ABORT("fatal error");
-@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
-             if (V->ne[0] != 512) {
-                 return BEST_FATTN_KERNEL_NONE;
-             }
--            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
-+            if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
-                 return BEST_FATTN_KERNEL_NONE;
-             }
-             break;
-diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
-index 2074e954a..517993cb0 100644
---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
-+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
-@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
- DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
- DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
- DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
-+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
-diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
-index 24c64cf00..97b19c67a 100644
---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
-+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
-@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
- DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
- DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
- DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
-+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
-diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
-index 1ada657f1..989626dfa 100644
---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
-+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
-@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
- DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
- DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
- DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
-+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
-diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
-index 86d4ffae2..173de7aac 100644
---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
-+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
-@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
- DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
- DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
- DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
-+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
-diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
-index f24270bb1..7b5ee968c 100644
---- a/ggml/src/ggml-metal/ggml-metal-device.m
-+++ b/ggml/src/ggml-metal/ggml-metal-device.m
-@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
-                 op->src[0]->ne[0] != 112 &&
-                 op->src[0]->ne[0] != 128 &&
-                 op->src[0]->ne[0] != 192 &&
--                op->src[0]->ne[0] != 256) {
--                return false;
--            }
--            if (op->src[0]->ne[0] == 576) {
--                // DeepSeek sizes
--                // TODO: disabled for now, until optmized
-+                op->src[0]->ne[0] != 256 &&
-+                op->src[0]->ne[0] != 576) {
-                 return false;
-             }
-             if (op->src[1]->type != op->src[2]->type) {
-diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-index e99c1763f..80864f303 100644
---- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
-+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
- 
-         // simdgroups per threadgroup (a.k.a. warps)
-         //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
--        int32_t nsg = 4;
-+        int32_t nsg = ne00 >= 512 ? 8 : 4;
- 
-         const size_t smem = FATTN_SMEM(nsg);
- 
-diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index c98d269d1..d33c16079 100644
---- a/ggml/src/ggml-metal/ggml-metal.metal
-+++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
-       //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break;
-       //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break;
-         case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break;
-+        case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break;
-     }
- #undef FWD_TMPL
- #undef FWD_ARGS
diff --git a/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch b/llama/patches/0032-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
similarity index 91%
rename from llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
rename to llama/patches/0032-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
index ff0c8199d5a..b649cae4469 100644
--- a/llama/patches/0034-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
+++ b/llama/patches/0032-ggml-metal-guard-mul_mat_id-map0-and-add-ne20-22-spe.patch
@@ -10,10 +10,10 @@ Subject: [PATCH] ggml-metal: guard mul_mat_id map0 and add ne20=22
  2 files changed, 3 insertions(+), 1 deletion(-)
 
 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-index 4ac135603..ac5ad53db 100644
+index 3d5db0b79..771cb3876 100644
 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
 +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-@@ -1961,7 +1961,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
+@@ -2207,7 +2207,8 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
      // ne21 = n_rows (batch size)
      const int ne21_mm_id_min = 32;
  
@@ -24,10 +24,10 @@ index 4ac135603..ac5ad53db 100644
          // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
          //switch (op->src[0]->type) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index c37447a10..4f338aa13 100644
+index 1fbdfc385..36b5a6812 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -9427,6 +9427,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
+@@ -9127,6 +9127,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
  template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
  template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
  template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
diff --git a/llama/patches/0033-ggml-metal-solve_tri.patch b/llama/patches/0033-ggml-metal-solve_tri.patch
deleted file mode 100644
index 7bc65fda791..00000000000
--- a/llama/patches/0033-ggml-metal-solve_tri.patch
+++ /dev/null
@@ -1,276 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Jeffrey Morgan <jmorganca@gmail.com>
-Date: Tue, 3 Feb 2026 12:00:00 -0800
-Subject: [PATCH] ggml: metal solve_tri
-
----
- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
- ggml/src/ggml-metal/ggml-metal-device.h   |  1 +
- ggml/src/ggml-metal/ggml-metal-device.m   | 11 ++++
- ggml/src/ggml-metal/ggml-metal-impl.h     | 21 ++++++++
- ggml/src/ggml-metal/ggml-metal-ops.cpp    | 63 +++++++++++++++++++++++
- ggml/src/ggml-metal/ggml-metal-ops.h      |  1 +
- ggml/src/ggml-metal/ggml-metal.metal      | 60 +++++++++++++++++++++
- 7 files changed, 177 insertions(+)
-
-diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
-index 680904d13..83385c9ef 100644
---- a/ggml/src/ggml-metal/ggml-metal-device.cpp
-+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
-@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
-     return res;
- }
- 
-+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
-+    assert(op->op == GGML_OP_SOLVE_TRI);
-+
-+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-+    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
-+
-+    char base[256];
-+    char name[256];
-+
-+    snprintf(base, 256, "kernel_solve_tri_f32");
-+    snprintf(name, 256, "%s", base);
-+
-+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
-+    if (!res.pipeline) {
-+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-+    }
-+
-+    return res;
-+}
-+
- ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
-     assert(op->op == GGML_OP_GROUP_NORM);
- 
-diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
-index 0a8b9211a..8a9d17460 100644
---- a/ggml/src/ggml-metal/ggml-metal-device.h
-+++ b/ggml/src/ggml-metal/ggml-metal-device.h
-@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
- struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
-diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
-index 7b5ee968c..4e5acfbe5 100644
---- a/ggml/src/ggml-metal/ggml-metal-device.m
-+++ b/ggml/src/ggml-metal/ggml-metal-device.m
-@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
-             return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
-         case GGML_OP_L2_NORM:
-             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
-+        case GGML_OP_SOLVE_TRI:
-+            return ggml_is_contiguous(op->src[0]) &&
-+                ggml_is_contiguous(op->src[1]) &&
-+                op->src[0]->type == GGML_TYPE_F32 &&
-+                op->src[1]->type == GGML_TYPE_F32 &&
-+                op->type == GGML_TYPE_F32;
-+        case GGML_OP_COUNT_EQUAL:
-+            return has_simdgroup_reduction &&
-+                op->src[0]->type == GGML_TYPE_I32 &&
-+                op->src[1]->type == GGML_TYPE_I32 &&
-+                op->type == GGML_TYPE_I64;
-         case GGML_OP_ARGMAX:
-             return has_simdgroup_reduction;
-         case GGML_OP_NORM:
-diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
-index 8944b07e9..cfdea9c07 100644
---- a/ggml/src/ggml-metal/ggml-metal-impl.h
-+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
-@@ -500,6 +500,27 @@ typedef struct {
-     float    eps;
- } ggml_metal_kargs_l2_norm;
- 
-+typedef struct {
-+    int32_t  ne00;
-+    int32_t  ne01;
-+    int32_t  ne02;
-+    int32_t  ne03;
-+    uint64_t nb00;
-+    uint64_t nb01;
-+    uint64_t nb02;
-+    uint64_t nb03;
-+    int32_t  ne10;
-+    int32_t  ne11;
-+    uint64_t nb10;
-+    uint64_t nb11;
-+    uint64_t nb12;
-+    uint64_t nb13;
-+    uint64_t nb0;
-+    uint64_t nb1;
-+    uint64_t nb2;
-+    uint64_t nb3;
-+} ggml_metal_kargs_solve_tri;
-+
- typedef struct {
-     int64_t  ne00;
-     int64_t  ne01;
-diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-index 80864f303..4ac135603 100644
---- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
-+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
-@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
-             {
-                 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
-             } break;
-+        case GGML_OP_SOLVE_TRI:
-+            {
-+                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
-+            } break;
-         case GGML_OP_GROUP_NORM:
-             {
-                 n_fuse = ggml_metal_op_group_norm(ctx, idx);
-@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
-     return 1;
- }
- 
-+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
-+    ggml_tensor * op = ctx->node(idx);
-+
-+    ggml_metal_library_t lib = ctx->lib;
-+    ggml_metal_encoder_t enc = ctx->enc;
-+
-+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
-+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
-+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-+
-+    ggml_metal_kargs_solve_tri args = {
-+        /*.ne00 =*/ ne00,
-+        /*.ne01 =*/ ne01,
-+        /*.ne02 =*/ ne02,
-+        /*.ne03 =*/ ne03,
-+        /*.nb00 =*/ nb00,
-+        /*.nb01 =*/ nb01,
-+        /*.nb02 =*/ nb02,
-+        /*.nb03 =*/ nb03,
-+        /*.ne10 =*/ ne10,
-+        /*.ne11 =*/ ne11,
-+        /*.nb10 =*/ nb10,
-+        /*.nb11 =*/ nb11,
-+        /*.nb12 =*/ nb12,
-+        /*.nb13 =*/ nb13,
-+        /*.nb0  =*/ nb0,
-+        /*.nb1  =*/ nb1,
-+        /*.nb2  =*/ nb2,
-+        /*.nb3  =*/ nb3,
-+    };
-+
-+    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
-+
-+    const int64_t ncols = ne10;
-+    const int64_t n_batches = (int64_t)ne02 * ne03;
-+    const int64_t nr = n_batches * ncols;
-+
-+    int nth = 64;
-+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-+    if (nth < 1) {
-+        nth = 1;
-+    }
-+
-+    const int64_t n_tg = (nr + nth - 1) / nth;
-+
-+    ggml_metal_encoder_set_pipeline(enc, pipeline);
-+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
-+
-+    ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
-+
-+    return 1;
-+}
-+
- int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
-     ggml_tensor * op = ctx->node(idx);
- 
-diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
-index 902b54452..a475183d3 100644
---- a/ggml/src/ggml-metal/ggml-metal-ops.h
-+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
-@@ -68,6 +68,7 @@ int ggml_metal_op_add_id            (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_flash_attn_ext    (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_bin               (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_l2_norm           (ggml_metal_op_t ctx, int idx);
-+int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
- int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
-diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index d33c16079..c37447a10 100644
---- a/ggml/src/ggml-metal/ggml-metal.metal
-+++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
-     }
- }
- 
-+kernel void kernel_solve_tri_f32(
-+        constant ggml_metal_kargs_solve_tri & args,
-+        device const char * src0,
-+        device const char * src1,
-+        device       char * dst,
-+        uint   tgpig[[threadgroup_position_in_grid]],
-+        ushort tpitg[[thread_position_in_threadgroup]],
-+        ushort   ntg[[threads_per_threadgroup]]) {
-+    const uint64_t ncols = (uint64_t) args.ne10;
-+    const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
-+    const uint64_t nr = n_batches * ncols;
-+
-+    const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
-+    if (gid >= nr) {
-+        return;
-+    }
-+
-+    const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
-+    const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
-+    const uint64_t i02 = rem / ncols;
-+    const uint64_t i01 = rem - i02 * ncols;
-+
-+    const uint64_t sa0 = args.nb00 / sizeof(float);
-+    const uint64_t sa1 = args.nb01 / sizeof(float);
-+    const uint64_t sa2 = args.nb02 / sizeof(float);
-+    const uint64_t sa3 = args.nb03 / sizeof(float);
-+
-+    const uint64_t sb0 = args.nb10 / sizeof(float);
-+    const uint64_t sb1 = args.nb11 / sizeof(float);
-+    const uint64_t sb2 = args.nb12 / sizeof(float);
-+    const uint64_t sb3 = args.nb13 / sizeof(float);
-+
-+    const uint64_t sx0 = args.nb0 / sizeof(float);
-+    const uint64_t sx1 = args.nb1 / sizeof(float);
-+    const uint64_t sx2 = args.nb2 / sizeof(float);
-+    const uint64_t sx3 = args.nb3 / sizeof(float);
-+
-+    device const float * A = (device const float *) src0;
-+    device const float * B = (device const float *) src1;
-+    device       float * X = (device       float *) dst;
-+
-+    const uint64_t A_base = i02 * sa2 + i03 * sa3;
-+    const uint64_t B_base = i02 * sb2 + i03 * sb3;
-+    const uint64_t X_base = i02 * sx2 + i03 * sx3;
-+
-+    const uint64_t n = (uint64_t) args.ne11;
-+
-+    for (uint64_t i00 = 0; i00 < n; ++i00) {
-+        float sum = 0.0f;
-+        for (uint64_t t = 0; t < i00; ++t) {
-+            sum += A[A_base + i00 * sa1 + t * sa0] *
-+                X[X_base + t * sx1 + i01 * sx0];
-+        }
-+
-+        const float diag = A[A_base + i00 * sa1 + i00 * sa0];
-+        X[X_base + i00 * sx1 + i01 * sx0] =
-+            (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
-+    }
-+}
-+
- kernel void kernel_group_norm_f32(
-         constant ggml_metal_kargs_group_norm & args,
-         device const float * src0,
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 9ae5fd578f6..8ad39474c14 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -72,7 +72,7 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
     try {
         const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
         std::vector<std::string> splits = {};
-        llama_model_loader ml(std::string(fname), splits, false, false, false, nullptr, nullptr);
+        llama_model_loader ml(std::string(fname), splits, false, false, false, false, nullptr, nullptr);
         vocab->load(ml, kv);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
diff --git a/llm/server.go b/llm/server.go
index 291dd47fe20..4a944cc15b4 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -688,7 +688,6 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 	if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
 		(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
 		(len(gpus) == 0 && s.options.UseMMap == nil) ||
-		(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
 		(s.options.UseMMap != nil && !*s.options.UseMMap) {
 		s.loadRequest.UseMmap = false
 	}
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index c2ce30dd8cf..1312745e3e0 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1765,7 +1765,7 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
 func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
 	return &Tensor{
 		b: t.b,
-		t: C.ggml_argsort_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
+		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
 	}
 }
 
diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter
index 449ec9e5d0b..5036c245430 100644
--- a/ml/backend/ggml/ggml/.rsync-filter
+++ b/ml/backend/ggml/ggml/.rsync-filter
@@ -13,6 +13,7 @@ include /src/ggml-cpu/
 include /src/ggml-cpu/amx/
 include /src/ggml-cpu/arch/
 include /src/ggml-cpu/arch/arm/
+include /src/ggml-cpu/arch/powerpc/
 include /src/ggml-cpu/arch/x86/
 include /src/ggml-cpu/llamafile/
 include /src/ggml-cuda/
diff --git a/ml/backend/ggml/ggml/LICENSE b/ml/backend/ggml/ggml/LICENSE
index acb96ce78e0..e7dca554bcb 100644
--- a/ml/backend/ggml/ggml/LICENSE
+++ b/ml/backend/ggml/ggml/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2023-2024 The ggml authors
+Copyright (c) 2023-2026 The ggml authors
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h
index 6ad583f09e7..21c46f4fc2d 100644
--- a/ml/backend/ggml/ggml/include/ggml-backend.h
+++ b/ml/backend/ggml/ggml/include/ggml-backend.h
@@ -158,6 +158,7 @@ extern "C" {
         const char * description;
         // device free memory in bytes
         size_t memory_free;
+        // device UUID
         const char * id;
         // device total memory in bytes
         size_t memory_total;
@@ -371,7 +372,7 @@ extern "C" {
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
 
     // Tensor initialization
     GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
diff --git a/ml/backend/ggml/ggml/include/ggml-cann.h b/ml/backend/ggml/ggml/include/ggml-cann.h
index b469e228d06..74af465337a 100644
--- a/ml/backend/ggml/ggml/include/ggml-cann.h
+++ b/ml/backend/ggml/ggml/include/ggml-cann.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h
index 4f3b99c8d07..e3e067c916f 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpu.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpu.h
@@ -19,6 +19,9 @@ extern "C" {
         // abort ggml_graph_compute when true
         ggml_abort_callback abort_callback;
         void *              abort_callback_data;
+
+        // use only reference implementations
+        bool use_ref;
     };
 
     // numa strategies
@@ -132,6 +135,8 @@ extern "C" {
     GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
     GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
+    GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
+
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
diff --git a/ml/backend/ggml/ggml/include/ggml-virtgpu.h b/ml/backend/ggml/ggml/include/ggml-virtgpu.h
new file mode 100644
index 00000000000..faaba8f246d
--- /dev/null
+++ b/ml/backend/ggml/ggml/include/ggml-virtgpu.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h
index 20c912d0e9b..fcc51f1f71a 100644
--- a/ml/backend/ggml/ggml/include/ggml.h
+++ b/ml/backend/ggml/ggml/include/ggml.h
@@ -6,7 +6,7 @@
 // This documentation is still a work in progress.
 // If you wish some specific topics to be covered, feel free to drop a comment:
 //
-//   https://github.com/ggerganov/whisper.cpp/issues/40
+//   https://github.com/ggml-org/whisper.cpp/issues/40
 //
 // ## Overview
 //
@@ -234,6 +234,11 @@
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define GGML_MEM_ALIGN 4
+#elif defined(__EMSCRIPTEN__)
+// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
+// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
+// ref: https://github.com/ggml-org/llama.cpp/pull/18628
+    #define GGML_MEM_ALIGN 8
 #else
     #define GGML_MEM_ALIGN 16
 #endif
@@ -625,10 +630,11 @@ extern "C" {
 
     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  =  1, // ...is an input for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT =  2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM  =  4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
     };
 
     enum ggml_tri_type {
@@ -724,10 +730,6 @@ extern "C" {
     GGML_API size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
     GGML_API size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
 
-    GGML_DEPRECATED(
-    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
-    "use ggml_row_size() instead");
-
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
@@ -746,6 +748,7 @@ extern "C" {
     GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_empty     (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_view      (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_vector    (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_matrix    (const struct ggml_tensor * tensor);
@@ -2572,11 +2575,42 @@ extern "C" {
         struct ggml_tensor *  grad,
         struct ggml_tensor *  sgd_params); // alpha, weight decay
 
+    // build forward multiple tensors and select one of them for computing
+    // this is useful for creating graphs that have constant topology but compute different things based on the input
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
+    //
+    // nodes:
+    //   | - build forward into the graph but do not compute
+    //   c - build forward into the graph and compute
+    //
+    //    |  |  ...  c  ...  |
+    //    |  |  ...  c  ...  |
+    //    |  |  ...  c  ...  |
+    //   [0  1  ... idx ...  n-1]        <-- ggml_build_forward_select(..., n, idx)
+    //               c
+    //               c
+    //
+    // example:
+    //   struct ggml_tensor * curs[3];
+    //
+    //   curs[0]  = compute0(...);
+    //   curs[1]  = compute1(...);
+    //   curs[2]  = compute2(...);
     //
-    // automatic differentiation
+    //   int idx = select_branch(some_input);
     //
+    //   struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
+    //
+    GGML_API struct ggml_tensor * ggml_build_forward_select(
+            struct ggml_cgraph  * cgraph,
+            struct ggml_tensor ** tensors,
+            int                   n_tensors,
+            int                   idx);
+
+    GGML_API void ggml_build_forward_expand(
+            struct ggml_cgraph * cgraph,
+            struct ggml_tensor * tensor);
 
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
         struct ggml_context *  ctx,        // context for gradient computation
         struct ggml_cgraph  *  cgraph,
@@ -2608,7 +2642,7 @@ extern "C" {
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
     // dump the graph into a file using the dot format
-    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
 
     // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt
index 9a134b7af28..201a260a329 100644
--- a/ml/backend/ggml/ggml/src/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/CMakeLists.txt
@@ -225,6 +225,7 @@ if (GGML_SCHED_NO_REALLOC)
 endif()
 
 add_library(ggml
+            ggml-backend-dl.cpp
             ggml-backend-reg.cpp)
 add_library(ggml::ggml ALIAS ggml)
 
@@ -362,12 +363,27 @@ if (GGML_CPU_ALL_VARIANTS)
     add_custom_target(ggml-cpu)
     if (GGML_SYSTEM_ARCH STREQUAL "x86")
         ggml_add_cpu_backend_variant(x64)
-        ggml_add_cpu_backend_variant(sse42        SSE42)
-        ggml_add_cpu_backend_variant(sandybridge  SSE42 AVX)
-        ggml_add_cpu_backend_variant(haswell      SSE42 AVX F16C AVX2 BMI2 FMA)
-        ggml_add_cpu_backend_variant(skylakex     SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
-        ggml_add_cpu_backend_variant(icelake      SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-        ggml_add_cpu_backend_variant(alderlake    SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+        ggml_add_cpu_backend_variant(sse42              SSE42)
+        ggml_add_cpu_backend_variant(sandybridge        SSE42 AVX)
+        if (NOT MSVC)
+            # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+            ggml_add_cpu_backend_variant(ivybridge      SSE42 AVX F16C)
+            ggml_add_cpu_backend_variant(piledriver     SSE42 AVX F16C FMA)
+        endif()
+        ggml_add_cpu_backend_variant(haswell            SSE42 AVX F16C FMA AVX2 BMI2)
+        ggml_add_cpu_backend_variant(skylakex           SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
+        ggml_add_cpu_backend_variant(cannonlake         SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
+        ggml_add_cpu_backend_variant(cascadelake        SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
+        ggml_add_cpu_backend_variant(icelake            SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
+        if (NOT MSVC)
+            # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
+            # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
+            # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
+            ggml_add_cpu_backend_variant(cooperlake     SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
+            ggml_add_cpu_backend_variant(zen4           SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
+        endif()
+        ggml_add_cpu_backend_variant(alderlake          SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
+        # AMX variants removed by ollama - sapphirerapids with AMX_TILE AMX_INT8 not included
     elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
         if (CMAKE_SYSTEM_NAME MATCHES "Linux")
             # Many of these features are optional so we build versions with popular
@@ -387,6 +403,9 @@ if (GGML_CPU_ALL_VARIANTS)
             ggml_add_cpu_backend_variant(android_armv8.2_1    DOTPROD)
             ggml_add_cpu_backend_variant(android_armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)
             ggml_add_cpu_backend_variant(android_armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
+            ggml_add_cpu_backend_variant(android_armv9.0_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
+            ggml_add_cpu_backend_variant(android_armv9.2_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+            ggml_add_cpu_backend_variant(android_armv9.2_2    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
         elseif (APPLE)
             ggml_add_cpu_backend_variant(apple_m1             DOTPROD)
             ggml_add_cpu_backend_variant(apple_m2_m3          DOTPROD MATMUL_INT8)
@@ -435,6 +454,7 @@ ggml_add_backend(HIP)
 ggml_add_backend(METAL)
 ggml_add_backend(MUSA)
 ggml_add_backend(RPC)
+ggml_add_backend(VirtGPU)
 ggml_add_backend(SYCL)
 ggml_add_backend(Vulkan)
 ggml_add_backend(WebGPU)
diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c
index 73b39bfea08..d32b016afaa 100644
--- a/ml/backend/ggml/ggml/src/ggml-alloc.c
+++ b/ml/backend/ggml/ggml/src/ggml-alloc.c
@@ -17,11 +17,6 @@
 //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
 #define AT_PRINTF(...)
 
-
-static bool ggml_is_view(const struct ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
@@ -632,7 +627,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
     GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
 
-    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
         hn->allocated = true;
         assert(hn->addr.offset == 0);
 
@@ -663,7 +658,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
 
                 struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
                 if (p_hn->n_children == 1 && p_hn->n_views == 0) {
-                    if (ggml_is_view(parent)) {
+                    if (ggml_impl_is_view(parent)) {
                         struct ggml_tensor * view_src = parent->view_src;
                         struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
                         if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
@@ -744,7 +739,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
         // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
         // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
         // itself is never used and should not be considered a dependency
-        if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
+        if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
             struct ggml_tensor * view_src = node->view_src;
             ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
         }
@@ -811,7 +806,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
                 parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
 
             if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                if (ggml_is_view(parent)) {
+                if (ggml_impl_is_view(parent)) {
                     struct ggml_tensor * view_src = parent->view_src;
                     struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
                     view_src_hn->n_views -= 1;
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-dl.cpp b/ml/backend/ggml/ggml/src/ggml-backend-dl.cpp
new file mode 100644
index 00000000000..34949c26163
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-backend-dl.cpp
@@ -0,0 +1,76 @@
+#include "ggml-backend-dl.h"
+#include "ggml-impl.h"
+
+#ifdef _WIN32
+
+static std::string path_str(const fs::path & path) {
+    try {
+#if defined(__cpp_lib_char8_t)
+        // C++20 and later: u8string() returns std::u8string
+        const std::u8string u8str = path.u8string();
+        return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
+#else
+        // C++17: u8string() returns std::string
+        return path.u8string();
+#endif
+    } catch (...) {
+        return std::string();
+    }
+}
+
+dl_handle * dl_load_library(const fs::path & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.wstring().c_str());
+    if (!handle) {
+        DWORD error_code = GetLastError();
+        std::string msg;
+        LPSTR lpMsgBuf = NULL;
+        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
+        if (bufLen) {
+            msg = lpMsgBuf;
+            LocalFree(lpMsgBuf);
+            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
+        }
+    }
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+const char * dl_error() {
+    return "";
+}
+
+#else
+
+dl_handle * dl_load_library(const fs::path & path) {
+    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
+    return handle;
+}
+
+void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+const char * dl_error() {
+    const char *rslt = dlerror();
+    return rslt != nullptr ? rslt : "";
+}
+
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-dl.h b/ml/backend/ggml/ggml/src/ggml-backend-dl.h
new file mode 100644
index 00000000000..f74b7c94894
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-backend-dl.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#ifdef _WIN32
+#   define WIN32_LEAN_AND_MEAN
+#   ifndef NOMINMAX
+#       define NOMINMAX
+#   endif
+#   include <windows.h>
+#   include 
+#else
+#    include <dlfcn.h>
+#    include <unistd.h>
+#endif
+#include <filesystem>
+
+namespace fs = std::filesystem;
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+dl_handle * dl_load_library(const fs::path & path);
+void * dl_get_sym(dl_handle * handle, const char * name);
+const char * dl_error();
+
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
index 21b35ac5c79..c815c2eeda5 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
@@ -160,7 +160,7 @@ extern "C" {
         // device description: short informative description of the device, could be the model name
         const char * (*get_description)(ggml_backend_dev_t dev);
 
-        // device memory in bytes
+        // device memory in bytes: 0 bytes to indicate no memory to report
         void         (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
 
         // device type
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
index 2474e0ed685..03e32b2d546 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
@@ -1,5 +1,6 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
+#include "ggml-backend-dl.h"
 #include "ggml-impl.h"
 #include <algorithm>
 #include <cstring>
@@ -69,6 +70,10 @@
 #include "ggml-rpc.h"
 #endif
 
+#ifdef GGML_USE_VIRTGPU_FRONTEND
+#include "ggml-virtgpu.h"
+#endif
+
 #ifdef GGML_USE_CANN
 #include "ggml-cann.h"
 #endif
@@ -77,117 +82,23 @@
 #include "ggml-zendnn.h"
 #endif
 
-// disable C++17 deprecation warning for std::codecvt_utf8
-#if defined(__clang__)
-#    pragma clang diagnostic push
-#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic push
-#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
 namespace fs = std::filesystem;
 
 static std::string path_str(const fs::path & path) {
-    std::string u8path;
     try {
 #if defined(__cpp_lib_char8_t)
         // C++20 and later: u8string() returns std::u8string
-        std::u8string u8str = path.u8string();
-        u8path = std::string(reinterpret_cast<const char *>(u8str.c_str()));
+        const std::u8string u8str = path.u8string();
+        return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
 #else
         // C++17: u8string() returns std::string
-        u8path = path.u8string();
+        return path.u8string();
 #endif
     } catch (...) {
+        return std::string();
     }
-    return u8path;
-}
-
-#if defined(__clang__)
-#    pragma clang diagnostic pop
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic pop
-#endif
-
-#ifdef _WIN32
-
-using dl_handle = std::remove_pointer_t<HMODULE>;
-
-struct dl_handle_deleter {
-    void operator()(HMODULE handle) {
-        FreeLibrary(handle);
-    }
-};
-
-static dl_handle * dl_load_library(const fs::path & path) {
-    // suppress error dialogs for missing DLLs
-    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
-    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
-
-    HMODULE handle = LoadLibraryW(path.wstring().c_str());
-    if (!handle) {
-        DWORD error_code = GetLastError();
-        std::string msg;
-        LPSTR lpMsgBuf = NULL;
-        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-        if (bufLen) {
-            msg = lpMsgBuf;
-            LocalFree(lpMsgBuf);
-            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
-        }
-    }
-
-    SetErrorMode(old_mode);
-
-    return handle;
-}
-
-static void * dl_get_sym(dl_handle * handle, const char * name) {
-    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
-    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
-
-    void * p = (void *) GetProcAddress(handle, name);
-
-    SetErrorMode(old_mode);
-
-    return p;
 }
 
-static const char * dl_error() {
-    return "";
-}
-
-#else
-
-using dl_handle = void;
-
-struct dl_handle_deleter {
-    void operator()(void * handle) {
-        dlclose(handle);
-    }
-};
-
-static void * dl_load_library(const fs::path & path) {
-    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
-
-    return handle;
-}
-
-static void * dl_get_sym(dl_handle * handle, const char * name) {
-    return dlsym(handle, name);
-}
-
-static const char * dl_error() {
-    const char *rslt = dlerror();
-    return rslt != nullptr ? rslt : "";
-}
-
-#endif
-
-using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
-
 struct ggml_backend_reg_entry {
     ggml_backend_reg_t reg;
     dl_handle_ptr handle;
@@ -208,7 +119,12 @@ struct ggml_backend_registry {
         register_backend(ggml_backend_sycl_reg());
 #endif
 #ifdef GGML_USE_VULKAN
+    // Add runtime disable check
+    if (getenv("GGML_DISABLE_VULKAN") == nullptr) {
         register_backend(ggml_backend_vk_reg());
+    } else {
+        GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n");
+    }
 #endif
 #ifdef GGML_USE_WEBGPU
         register_backend(ggml_backend_webgpu_reg());
@@ -216,6 +132,10 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_ZDNN
         register_backend(ggml_backend_zdnn_reg());
 #endif
+#ifdef GGML_USE_VIRTGPU_FRONTEND
+        register_backend(ggml_backend_virtgpu_reg());
+#endif
+
 #ifdef GGML_USE_OPENCL
         register_backend(ggml_backend_opencl_reg());
 #endif
@@ -556,9 +476,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
 
     int best_score = 0;
     fs::path best_path;
+    std::error_code ec;
 
     for (const auto & search_path : search_paths) {
-        if (std::error_code ec; !fs::exists(search_path, ec)) {
+        if (!fs::exists(search_path, ec)) {
             if (ec) {
                 GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
             } else {
@@ -568,7 +489,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
-            if (entry.is_regular_file()) {
+            if (entry.is_regular_file(ec)) {
                 auto filename = entry.path().filename();
                 auto ext = entry.path().extension();
                 if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
@@ -637,6 +558,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("rpc", silent, dir_path);
     ggml_backend_load_best("sycl", silent, dir_path);
     ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("virtgpu", silent, dir_path);
     ggml_backend_load_best("opencl", silent, dir_path);
     ggml_backend_load_best("hexagon", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp
index 189e97170ff..4d8abe025e8 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp
@@ -278,6 +278,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
     if (backend->iface.set_tensor_async == NULL) {
+        ggml_backend_synchronize(backend);
         ggml_backend_tensor_set(tensor, data, offset, size);
     } else {
         backend->iface.set_tensor_async(backend, tensor, data, offset, size);
@@ -291,6 +292,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
     if (backend->iface.get_tensor_async == NULL) {
+        ggml_backend_synchronize(backend);
         ggml_backend_tensor_get(tensor, data, offset, size);
     } else {
         backend->iface.get_tensor_async(backend, tensor, data, offset, size);
@@ -911,9 +913,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
                 fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
-                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1747,6 +1749,10 @@ ggml_backend_sched_t ggml_backend_sched_new_ext(
     return sched;
 }
 
+void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) {
+    sched->batch_size = batch_size;
+}
+
 void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     if (sched == NULL) {
         return;
@@ -1776,10 +1782,6 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     free(sched);
 }
 
-void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) {
-    sched->batch_size = batch_size;
-}
-
 void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     GGML_ASSERT(sched);
     // reset state for the next run
@@ -2013,6 +2015,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
         dst->view_offs = src->view_offs;
     }
     dst->op = src->op;
+    dst->flags = src->flags;
     memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
     ggml_set_name(dst, src->name);
 
@@ -2144,7 +2147,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
     ggml_free(copy.ctx_unallocated);
 }
 
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
     struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
     if (copy.buffer == NULL) {
         return false;
@@ -2155,22 +2158,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     assert(g1->n_nodes == g2->n_nodes);
 
-    if (test_node != nullptr) {
-        // Compute the whole graph and only test the output for a specific tensor
+    if (num_test_nodes != 0) {
+        GGML_ASSERT(test_nodes);
+        // Compute the whole graph and only test the output for specific tensors
         ggml_backend_graph_compute(backend1, g1);
         ggml_backend_graph_compute(backend2, g2);
 
-        int test_node_idx = -1;
+        bool verified = false;
         for (int i = 0; i < g1->n_nodes; i++) {
-            struct ggml_tensor * t1 = g1->nodes[i];
-            if (t1 == test_node) {
-                test_node_idx = i;
-                break;
+            for (size_t j = 0; j < num_test_nodes; ++j) {
+                if (g1->nodes[i] == test_nodes[j]) {
+                    callback(i, g1->nodes[i], g2->nodes[i], user_data);
+                    verified = true;
+                }
             }
         }
-        GGML_ASSERT(test_node_idx != -1);
-
-        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+        GGML_ASSERT(verified);
     } else {
         for (int i = 0; i < g1->n_nodes; i++) {
             struct ggml_tensor * t1 = g1->nodes[i];
diff --git a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt
index 60ce4b1e02c..c27dc174c00 100644
--- a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt
@@ -32,14 +32,12 @@ if (BLAS_FOUND)
                 pkg_check_modules(DepBLAS openblas)
             endif()
         elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
-            add_compile_definitions(GGML_BLAS_USE_BLIS)
             pkg_check_modules(DepBLAS blis)
         elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
             pkg_check_modules(DepBLAS blas-atlas)
         elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
             pkg_check_modules(DepBLAS flexiblas_api)
         elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
-            add_compile_definitions(GGML_BLAS_USE_MKL)
             # all Intel* libraries share the same include path
             pkg_check_modules(DepBLAS mkl-sdl)
         elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
@@ -74,12 +72,28 @@ if (BLAS_FOUND)
 
     target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
 
-    if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+    if ("${GGML_BLAS_VENDOR}" STREQUAL "")
+        message(WARNING "GGML_BLAS_VENDOR is not set; some methods may not link properly.")
+    endif()
+
+    if ("${GGML_BLAS_VENDOR}" MATCHES "Intel" OR ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND "${GGML_BLAS_VENDOR}" MATCHES "Generic"))
         add_compile_definitions(GGML_BLAS_USE_MKL)
     endif()
 
+    if ("${GGML_BLAS_VENDOR}" MATCHES "OpenBLAS")
+        add_compile_definitions(GGML_BLAS_USE_OPENBLAS)
+    endif()
+
+    if ("${GGML_BLAS_VENDOR}" MATCHES "FLAME" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL_mt")
+        add_compile_definitions(GGML_BLAS_USE_BLIS)
+    endif()
+
+    if ("${GGML_BLAS_VENDOR}" MATCHES "NVPL")
+        add_compile_definitions(GGML_BLAS_USE_NVPL)
+    endif()
+
     target_link_libraries     (ggml-blas PRIVATE ${BLAS_LIBRARIES})
-    target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
+    target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS})
 else()
     message(FATAL_ERROR "BLAS not found, please refer to "
                         "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
diff --git a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp
index 88d08895277..6a399bdb179 100644
--- a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp
@@ -115,15 +115,11 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 #endif
     }
 
-#if defined(OPENBLAS_VERSION)
+#if defined(GGML_BLAS_USE_OPENBLAS)
     openblas_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_BLIS)
+#elif defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
-#endif
-
-#if defined(GGML_BLAS_USE_NVPL)
+#elif defined(GGML_BLAS_USE_NVPL)
     nvpl_blas_set_num_threads(ctx->n_threads);
 #endif
 
@@ -230,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_backend_blas_mul_mat(ctx, node);
@@ -289,7 +289,7 @@ ggml_backend_t ggml_backend_blas_init(void) {
         /* .context = */ ctx,
     };
 
-#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+#if defined(GGML_BLAS_USE_OPENBLAS) && defined(GGML_USE_OPENMP)
     if (openblas_get_parallel() != OPENBLAS_OPENMP) {
         GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
     }
@@ -330,7 +330,7 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t
         return "BLIS";
     #elif defined(GGML_BLAS_USE_NVPL)
         return "NVPL";
-    #elif defined(OPENBLAS_VERSION)
+    #elif defined(GGML_BLAS_USE_OPENBLAS)
         return "OpenBLAS";
     #else
         return "BLAS";
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
index fc31089f3e2..3dc948e4d8e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
     set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    # Disable LTO for the feature detection code to prevent cross-module optimization
+    # from inlining architecture-specific instructions into the score function.
+    # Without this, LTO can cause SIGILL when loading backends on older CPUs
+    # (e.g., loading power10 backend on power9 crashes before feature check runs).
+    target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
     target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
 endfunction()
 
@@ -458,6 +463,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (GGML_RV_ZFH)
                 string(APPEND MARCH_STR "_zfh")
             endif()
+
             if (GGML_XTHEADVECTOR)
                 string(APPEND MARCH_STR "_xtheadvector")
             elseif (GGML_RVV)
@@ -465,6 +471,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 if (GGML_RV_ZVFH)
                     string(APPEND MARCH_STR "_zvfh")
                 endif()
+                if (GGML_RV_ZVFBFWMA)
+                    string(APPEND MARCH_STR "_zvfbfwma")
+                endif()
             endif()
             if (GGML_RV_ZICBOP)
                 string(APPEND MARCH_STR "_zicbop")
@@ -557,35 +566,32 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "45e110675d93f99f82c23a1afcca76bc")
+        set(KLEIDIAI_ARCHIVE_MD5  "0a9e9008adb6031f9e8cf70dff4a3321")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
         endif()
 
+        # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
+        # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
         FetchContent_Declare(KleidiAI_Download
             URL ${KLEIDIAI_DOWNLOAD_URL}
             DOWNLOAD_EXTRACT_TIMESTAMP NEW
             URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
 
-        FetchContent_MakeAvailable(KleidiAI_Download)
         FetchContent_GetProperties(KleidiAI_Download
             SOURCE_DIR  KLEIDIAI_SRC
             POPULATED   KLEIDIAI_POPULATED)
 
         if (NOT KLEIDIAI_POPULATED)
-            message(FATAL_ERROR "KleidiAI source downloaded failed.")
+            FetchContent_Populate(KleidiAI_Download)
+            FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
         endif()
 
         add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
 
-        # Remove kleidiai target after fetching it
-        if (TARGET kleidiai)
-            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
-        endif()
-
         list(APPEND GGML_CPU_SOURCES
             ggml-cpu/kleidiai/kleidiai.cpp
             ggml-cpu/kleidiai/kernels.cpp
@@ -611,6 +617,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)
 
         set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
 
@@ -655,6 +662,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()
 
+        if (NOT SVE_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
+        endif()
+
         set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
         list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
     endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
index 895a5713753..9baf3e025e6 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
 namespace ggml::cpu::amx {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        // handle only 2d gemm for now
-        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-        };
-
-        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
-            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
-            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
-            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
-            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
-            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
-            // src1 must be host buffer
-            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
+        auto * src0 = op->src[0];
+        auto * src1 = op->src[1];
+
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+            return false;
+        }
+        if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
+            return false;
+        }
+        if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
+            return false;
+        }
+        if (op->ne[0] % (TILE_N * 2)) {
+            return false;
+        }
+        int alignment;
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q8_0:
+                alignment = TILE_K;
+                break;
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ4_XS:
+                alignment = 256; // QK_K
+                break;
+            case GGML_TYPE_F16:
+                alignment = 16;
+                break;
+            default:
                 return false;
-            }
-            // src1 must be float32
-            if (op->src[1]->type == GGML_TYPE_F32) {
-                return true;
-            }
         }
-        return false;
+        if (src0->ne[0] % alignment) {
+            return false;
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            return false;
+        }
+        return true;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/amx/mmq.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/amx/mmq.cpp
index 47c61b88164..b5aca76633c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/amx/mmq.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/amx/mmq.cpp
@@ -1,4 +1,3 @@
-
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@@ -202,35 +201,27 @@ struct tile_config_t{
 //    advanced-matrix-extensions-intrinsics-functions.html
 //
 
-#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
-void ggml_tile_config_init(void) {
-    static thread_local bool is_first_time = true;
+inline void ggml_tile_config_init(void) {
+    static thread_local bool done = false;
 
-    if (!is_first_time) {
+    if (done) {
         return;
     }
 
-    static thread_local tile_config_t tc;
-    tile_config_t current_tc;
-    _tile_storeconfig(&current_tc);
-
-    // load only when config changes
-    if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
-                               memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
-        tc.palette_id = 1;
-        tc.start_row = 0;
-        TC_CONFIG_TILE(TMM0, 8, 64);
-        TC_CONFIG_TILE(TMM1, 8, 64);
-        TC_CONFIG_TILE(TMM2, 16, 32);
-        TC_CONFIG_TILE(TMM3, 16, 32);
-        TC_CONFIG_TILE(TMM4, 16, 64);
-        TC_CONFIG_TILE(TMM5, 16, 64);
-        TC_CONFIG_TILE(TMM6, 16, 64);
-        TC_CONFIG_TILE(TMM7, 16, 64);
-        _tile_loadconfig(&tc);
-    }
-
-    is_first_time = false;
+    alignas(64) tile_config_t tc = {};
+    tc.palette_id = 1;
+    tc.start_row = 0;
+    tc.rows[0] = 8;   tc.colsb[0] = 64;
+    tc.rows[1] = 8;   tc.colsb[1] = 64;
+    tc.rows[2] = 16;  tc.colsb[2] = 32;
+    tc.rows[3] = 16;  tc.colsb[3] = 32;
+    tc.rows[4] = 16;  tc.colsb[4] = 64;
+    tc.rows[5] = 16;  tc.colsb[5] = 64;
+    tc.rows[6] = 16;  tc.colsb[6] = 64;
+    tc.rows[7] = 16;  tc.colsb[7] = 64;
+
+    _tile_loadconfig(&tc);
+    done = true;
 }
 
 // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@@ -268,33 +259,6 @@ int get_row_size(int K) {
     return row_size;
 }
 
-// vectorized dtype conversion
-inline float FP16_TO_FP32(ggml_half val) {
-    __m256i v = _mm256_setr_epi16(
-        val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
-    __m512 o = _mm512_cvtph_ps(v);
-    return _mm512_cvtss_f32(o);
-}
-
-inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
-    __m256i v = _mm256_set1_epi16(val);
-    return _mm512_cvtph_ps(v);
-}
-
-// horizontal reduce
-inline float _mm512_reduce_max_ps(const __m512 x) {
-    __m512 v = x;
-    __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    return _mm512_cvtss_f32(v);
-}
-
 // transpose utils
 #define SHUFFLE_EPI32(a, b, mask) \
     _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx
 
 #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                \
     tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE>::apply(              \
-        K, (const float *)src1->data + mb_start * K,                                \
-        (const type *)src0->data + nb_start * K,                                    \
-        (float *)dst->data + mb_start * ldc + nb_start, ldc);
+        K, (const float *)src1->data + src1_offset + mb_start * K,                  \
+        (const type *)src0->data + src0_offset + nb_start * K,                      \
+        (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
 
 
 // re-organize in the format {NB, KB, TILE_SIZE}:
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni::apply(   \
-        KB, (const char *)wdata + 0 * row_size_A,                                    \
-        (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE),     \
-        (float *) dst->data + 0 * N + nb_start, ldc)
+#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                   \
+    tinygemm_kernel_vnni::apply(             \
+        KB, wdata_batch,                                                                       \
+        (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
+        (float *) dst->data + dst_offset + nb_start, ldc)
 
 template ::value, int>::type = 0>
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
         _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
 
         if (need_unpack) {
-            unpack_B(Tile1, B_blk0);
+            unpack_B(Tile1, B_blk1);
             _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
         } else {
             _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
     });
 }
 
+// ne2 is passed explicitly to help compiler optimize repeated calls
+inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
+    const int64_t i2 = batch_idx % ne2;
+    const int64_t i3 = batch_idx / ne2;
+    return i3 * t->nb[3] + i2 * t->nb[2];
+}
+
 size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
 
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 
     const int M = dst->ne[1];
     const int K = src0->ne[0];
+    const int64_t n_batch = dst->ne[2] * dst->ne[3];
 
     size_t desired_wsize = 0;
 
     GGML_DISPATCH_QTYPES(TYPE, [&] {
         const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-        desired_wsize = M * row_size_A;
+        desired_wsize = n_batch * M * row_size_A;
     });
 
     return desired_wsize;
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 // src1: input  in shape of {M, K}, float32
 // dst:  output in shape of {M, N}, float32
 //
-// the function performs: dst = src1 @ src0.T
+// the function performs: dst = src1 @ src0.T for each batch
 //
 void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int K = src0->ne[0];
     const int ldc = dst->nb[1] / dst->nb[0];
 
+    const int64_t ne2 = dst->ne[2];
+    const int64_t n_batch = ne2 * dst->ne[3];
+
     if (is_floating_type) {
         constexpr int BLOCK_M = 4;
         constexpr int BLOCK_N = 6;
         const int MB = div_up(M, BLOCK_M);
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
             GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
                 for (int i = begin; i < end; ++i) {
-                    int mb = i / NB;
-                    int nb = i % NB;
+                    int batch_idx = i / (MB * NB);
+                    int remaining = i % (MB * NB);
+                    int mb = remaining / NB;
+                    int nb = remaining % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
 
                     int mb_start = mb * BLOCK_M;
                     int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     void * wdata = params->wdata;
 
     //TODO: performance improvement: merge quant A
-    if (params->ith == 0) {
+ // if (params->ith == 0) {
         GGML_DISPATCH_QTYPES(TYPE, [&] {
             const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-            const size_t desired_wsize = M * row_size_A;
+            const size_t desired_wsize = n_batch * M * row_size_A;
             if (params->wsize < desired_wsize) {
                 GGML_ABORT("insufficient work space size");
             }
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
             GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
 
-            const float * A_data = static_cast(src1->data);
-            for (int m = 0; m < M; ++m) {
-                from_float(A_data + m * K, (char *)wdata + m * row_size_A, K);
-            }
+            parallel_for_ggml(params, n_batch, [&](int begin, int end) {
+                for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    const float * A_data = (const float *)((const char *)src1->data + src1_offset);
+                    char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
+
+                    for (int m = 0; m < M; ++m) {
+                        from_float(A_data + m * K, wdata_batch + m * row_size_A, K);
+                    }
+                }
+            });
         });
-    }
+ // }
 
     ggml_barrier(params->threadpool);
 
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
         constexpr int BLOCK_N = TILE_N * kTilesN;
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
             GGML_DISPATCH_QTYPES(TYPE, [&] {
                 const int KB = K / blck_size;
                 const int TILE_SIZE = get_tile_size();
                 const int row_size_A = KB * sizeof(vec_dot_type);
                 for (int i = begin; i < end; ++i) {
-                    int nb = i;
+                    int batch_idx = i / NB;
+                    int nb = i % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                    const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
+
                     int nb_start = nb * BLOCK_N;
                     int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
 
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int MB = div_up(M, BLOCK_M);
     const int NB = div_up(N, BLOCK_N);
 
-    parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+    parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
         // init tile config for each thread
         ggml_tile_config_init();
 
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             const int row_size_A = KB * sizeof(vec_dot_type);
 
             for (int i = begin; i < end; ++i) {
-                int mb = i / NB;
-                int nb = i % NB;
+                int batch_idx = i / (MB * NB);
+                int remaining = i % (MB * NB);
+                int mb = remaining / NB;
+                int nb = remaining % NB;
+
+                int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
 
                 int mb_start = mb * BLOCK_M;
                 int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
 
                 tinygemm_kernel_amx(
                     mb_size, nb_size, KB,
-                    (const char *)wdata + mb_start * row_size_A,
-                    (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
-                    (float *) dst->data + mb_start * N + nb_start, ldc);
+                    wdata_batch + mb_start * row_size_A,
+                    (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
+                    (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
             }
         });
     });
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h
index 0775c87f98b..ebbd4b47e05 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h
@@ -1,3 +1,4 @@
+
 #pragma once
 
 // Rename `_generic` functions if no native implementation is available.
@@ -38,26 +39,44 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
@@ -66,11 +85,25 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
 // quants.c
@@ -86,19 +119,35 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -114,31 +163,41 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
-#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
-#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
 #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
-#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
 #define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
-#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
-#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
-#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
@@ -149,18 +208,34 @@
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -182,19 +257,35 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -218,17 +309,33 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp
index fb7f074a850..3eed0105bf1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -25,9 +25,8 @@
 #define UNUSED GGML_UNUSED
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t *     out_mins,
-                                             int8_t *        out_scales) {
+// Helper for decoding scales and mins of Q4_K and Q5_K block formats
+static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
     constexpr uint32_t kmask1 = 0x3f3f3f3f;
     constexpr uint32_t kmask2 = 0x0f0f0f0f;
     constexpr uint32_t kmask3 = 0x03030303;
@@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    float * res_ptr = s;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        float32x4_t sumf = vdupq_n_f32(0);
+        for (int l = 0; l < nb; l++) {
+            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+            int32x4_t sumi = vdupq_n_s32(0);
+            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+            float32x4_t b_d = {
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+            };
+            float32x4_t d = a_d * b_d;
+
+            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+        }
+
+        vst1q_f32(res_ptr + x * 4, sumf);
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
@@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
@@ -701,7 +775,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
@@ -786,189 +860,1152 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
+void ggml_gemv_q5_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
 
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
 
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
     UNUSED(nb);
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const void * b_ptr = vx;
-    const void * a_ptr = vy;
-    float * res_ptr = s;
-    size_t res_stride = bs * sizeof(float);
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;  // 0123 and 4567
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
+    const uint8x16_t mone       = vdupq_n_u8(1);
+    const uint8x16_t mtwo       = vdupq_n_u8(2);
 
-    __asm__ __volatile__(
-        "mov x10, %x[nr]\n"
-        "mov x9, #0x88\n"
-        "cmp x10, #0x10\n"
-        "mul x9, %x[nb], x9\n"
-        "blt 4f\n"
-        "1:"  // Row loop
-        "add x28, %x[b_ptr], #0x8\n"
-        "mov x27, %x[nc]\n"
-        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
-        "2:"  // Column loop
-        "add x25, %x[a_ptr], #0x8\n"
-        "movi v15.16b, #0x0\n"
-        "movi v19.16b, #0x0\n"
-        "mov x24, %x[nb]\n"
-        "add x23, x25, x9\n"
-        "movi v18.16b, #0x0\n"
-        "movi v14.16b, #0x0\n"
-        "add x22, x23, x9\n"
-        "movi v11.16b, #0x0\n"
-        "movi v13.16b, #0x0\n"
-        "add x21, x22, x9\n"
-        "movi v23.16b, #0x0\n"
-        "movi v16.16b, #0x0\n"
-        "movi v25.16b, #0x0\n"
-        "movi v7.16b, #0x0\n"
-        "movi v0.16b, #0x0\n"
-        "movi v4.16b, #0x0\n"
-        "movi v5.16b, #0x0\n"
-        "movi v21.16b, #0x0\n"
-        "movi v8.16b, #0x0\n"
-        "movi v1.16b, #0x0\n"
-        "3:"  // Block loop
-        "ldr q3, [x28, #0x0]\n"
-        "ldr q31, [x25, #0x0]\n"
-        "movi v28.16b, #0x4\n"
-        "movi v10.4s, #0x0\n"
-        "ldr q22, [x28, #0x10]\n"
-        "ldr q6, [x25, #0x10]\n"
-        "movi v29.4s, #0x0\n"
-        "movi v9.4s, #0x0\n"
-        "ldr q27, [x28, #0x20]\n"
-        "ldr q30, [x28, #0x30]\n"
-        "movi v20.4s, #0x0\n"
-        "movi v24.16b, #0xf0\n"
-        "ldr d2, [x25, #-0x8]\n"
-        "ldr d26, [x23, #-0x8]\n"
-        "sshl v12.16b, v3.16b, v28.16b\n"
-        "sub x20, x28, #0x8\n"
-        "ldr d17, [x20, #0x0]\n"
-        "and v3.16b, v3.16b, v24.16b\n"
-        "subs x24, x24, #0x1\n"
-        "add x28, x28, #0x48\n"
-        ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
-        ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
-        ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
-        ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
-        "sshl v31.16b, v22.16b, v28.16b\n"
-        "and v22.16b, v22.16b, v24.16b\n"
-        "fcvtl v17.4s, v17.4h\n"
-        "fcvtl v2.4s, v2.4h\n"
-        "fcvtl v26.4s, v26.4h\n"
-        ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
-        ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
-        ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
-        ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
-        "sshl v6.16b, v27.16b, v28.16b\n"
-        "sshl v28.16b, v30.16b, v28.16b\n"
-        "and v27.16b, v27.16b, v24.16b\n"
-        "and v30.16b, v30.16b, v24.16b\n"
-        "ldr q24, [x25, #0x20]\n"
-        ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
-        ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
-        ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
-        "ldr q24, [x25, #0x30]\n"
-        ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
-        ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
-        ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
-        "ldr q24, [x25, #0x40]\n"
-        ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
-        ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
-        ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
-        "ldr q24, [x25, #0x50]\n"
-        ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
-        ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
-        ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
-        "ldr q24, [x25, #0x60]\n"
-        ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
-        ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
-        ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
-        "ldr q24, [x25, #0x70]\n"
-        "add x25, x25, #0x88\n"
-        ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
-        ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
-        ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
-        "fmul v24.4s, v17.4s, v2.s[0]\n"
-        "scvtf v10.4s, v10.4s, #0x4\n"
-        "scvtf v29.4s, v29.4s, #0x4\n"
-        "scvtf v9.4s, v9.4s, #0x4\n"
-        "scvtf v20.4s, v20.4s, #0x4\n"
-        "fmla v15.4s, v10.4s, v24.4s\n"
-        "ldr q24, [x23, #0x0]\n"
-        "fmul v10.4s, v17.4s, v2.s[1]\n"
-        "fmla v19.4s, v29.4s, v10.4s\n"
-        "ldr q10, [x23, #0x10]\n"
-        "fmul v29.4s, v17.4s, v2.s[2]\n"
-        "fmul v2.4s, v17.4s, v2.s[3]\n"
-        "fmla v18.4s, v9.4s, v29.4s\n"
-        "movi v9.4s, #0x0\n"
-        "movi v29.4s, #0x0\n"
-        ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
-        "fmla v14.4s, v20.4s, v2.4s\n"
-        "movi v20.4s, #0x0\n"
-        "movi v2.4s, #0x0\n"
-        ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
-        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
-        "ldr q24, [x23, #0x20]\n"
-        ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
-        ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
-        ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
-        ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
-        "ldr q10, [x23, #0x30]\n"
-        ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
-        ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
-        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
-        "ldr q24, [x23, #0x40]\n"
-        ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
-        ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
-        ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
-        ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
-        "ldr q10, [x23, #0x50]\n"
-        ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
-        ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
-        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
-        "ldr q24, [x23, #0x60]\n"
-        ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
-        ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
-        ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
-        ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
-        "ldr q10, [x23, #0x70]\n"
-        "add x23, x23, #0x88\n"
-        ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
-        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
-        ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
-        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
-        "ldr q24, [x22, #0x0]\n"
-        ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
-        ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q5_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d          = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q5_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0123   = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_4567   = vmulq_f32(q5_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            uint8x16_t qh[col_groups][8];
+            for (int c = 0; c < col_groups; c++) {
+                for (int i = 0; i < 8; i++) {
+                    qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+                }
+            }
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
+                int8x16_t q8_qs[4];
+                for (int i = 0; i < 4; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q5_cols[8];
+                    uint8x16_t hbit_lo[8];
+                    uint8x16_t hbit_hi[8];
+                    int8x16_t  q5_lo[8];
+                    int8x16_t  q5_hi[8];
+
+                    for (int i = 0; i < 8; i++) {
+                        q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                        hbit_lo[i] = vandq_u8(qh[c][i], mone);
+                        hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
+                        qh[c][i]   = vshrq_n_u8(qh[c][i], 2);
+                        q5_lo[i]   = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
+                        q5_hi[i]   = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
+                    }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t   sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+                const int16x4_t   sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0]                   = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t   sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+                const int16x4_t   sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1]                   = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+            }  // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    const uint8x16_t mone      = vdupq_n_u8(1);
+    const uint8x16_t mtwo      = vdupq_n_u8(2);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q5_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q5_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0   = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_1   = vmulq_f32(q5_dmin_1, q8_d);
+
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            // Load qh once per block and shift after each subblock
+            const uint8_t * qh_base = q5_ptr[b].qh;
+            uint8x16_t      qh[col_pairs][4];
+            for (int cp = 0; cp < col_pairs; cp++) {
+                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+            }
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
+                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t      q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q5s column pair loop unrolled
+                {
+                    // Cols 01
+                    uint8x16_t qs_0 = vld1q_u8(qs_base);
+                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
+                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
+                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
+
+                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
+                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
+                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
+                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
+                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
+                    uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
+                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
+                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
+
+                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);
+                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);
+                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);
+                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);
+
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 23
+                    qs_0 = vld1q_u8(qs_base + 16);
+                    qs_1 = vld1q_u8(qs_base + 80);
+                    qs_2 = vld1q_u8(qs_base + 144);
+                    qs_3 = vld1q_u8(qs_base + 208);
+
+                    hbit_lo_0 = vandq_u8(qh[1][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[1][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[1][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[1][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
+
+                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);
+                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);
+                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);
+                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);
+
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 45
+                    qs_0 = vld1q_u8(qs_base + 32);
+                    qs_1 = vld1q_u8(qs_base + 96);
+                    qs_2 = vld1q_u8(qs_base + 160);
+                    qs_3 = vld1q_u8(qs_base + 224);
+
+                    hbit_lo_0 = vandq_u8(qh[2][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[2][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[2][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[2][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
+
+                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);
+                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);
+                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);
+                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);
+
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 67
+                    qs_0 = vld1q_u8(qs_base + 48);
+                    qs_1 = vld1q_u8(qs_base + 112);
+                    qs_2 = vld1q_u8(qs_base + 176);
+                    qs_3 = vld1q_u8(qs_base + 240);
+
+                    hbit_lo_0 = vandq_u8(qh[3][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[3][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[3][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[3][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
+
+                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);
+                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);
+                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);
+                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);
+
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+                }
+
+                // Prepare bsum vectors for bias computation
+                // Each pair of subblocks share the same bsums
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // Iterates over a pair of column pairs (4 columns) to use a single 128 register
+                // p == 0 -> cols 0123, p == 2 -> cols 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
+                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
+                    int16x4_t   group_mins_lo   = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
+                    int16x4_t   group_mins_hi   = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
+                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
+                    float32x4_t sb_min          = p == 0 ? sb_min_0 : sb_min_1;
+
+                    // 0123 or 4567
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+
+                    // FUSED BIAS: Compute and subtract bias immediately
+                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
+                    int32x4_t bias       = vmull_s16(bsums_vec_lo, group_mins_lo);
+                    bias                 = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
+                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);
+                    acc_f32[i]           = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
+                }
+            }  // for sb
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo    = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi    = vdupq_n_u8(0x30);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[2];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
+
+            int32x4_t acc[col_groups];
+            for (int i = 0; i < col_groups; i++) {
+                acc[i] = vdupq_n_s32(0);
+            }
+
+            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
+            // Reused for bias and dequantization later
+            int16_t q6_scales[16 * 8];
+            for (int i = 0; i < 16; i++) {
+                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                vst1q_s16(q6_scales + i * 8, scales);
+            }
+
+            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
+            int32x4_t bias_lo = vdupq_n_s32(0);
+            int32x4_t bias_hi = vdupq_n_s32(0);
+
+            // Load bsums in chunks of 4 to process with vectorized operations
+            for (int i = 0; i < 16; i += 4) {
+                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
+                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
+                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
+                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
+                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
+                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
+                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
+                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
+                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
+
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
+            }
+            bias_lo = vshlq_n_s32(bias_lo, 5);
+            bias_hi = vshlq_n_s32(bias_hi, 5);
+
+            // Process two 128-value halves per superblock
+            for (int half = 0; half < 2; half++) {
+                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                // A subblock (sb) is a set of weights that share the scale
+                // Since q6_K scales are per 16 elements
+                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
+                    const int8_t * q8_base_h = q8_base_l + 64;
+
+                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)
+                    int8x16_t q8_l[4];
+                    int8x16_t q8_h[4];
+                    for (int i = 0; i < 4; i++) {
+                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
+                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
+                    }
+
+                    const int ql_off_base = sb * QK_K / 2;
+                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
+
+                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
+
+                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
+                    if (sb > 1) {
+                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
+                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
+                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
+                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
+                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
+                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
+                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
+                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
+                    }
+
+                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
+                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
+                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
+                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
+
+                    // Process column groups (0-3, 4-7)
+                    for (int g = 0; g < col_groups; g++) {
+                        int32x4_t sb_acc_l = vdupq_n_s32(0);
+                        int32x4_t sb_acc_h = vdupq_n_s32(0);
+
+                        for (int chunk = 0; chunk < 4; chunk++) {
+                            const int idx = chunk * 2 + g;
+
+                            const uint8x16_t q6_qs_l = q6_ql[idx];
+                            const uint8x16_t q6_qs_h = q6_qh[idx];
+
+                            // Extract high 2 bits for upper nibble reconstruction
+                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
+
+                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
+                            const int8x16_t q6_l =
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
+                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
+
+                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
+                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
+                        }
+
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
+                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
+
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
+                    }
+                }
+            }  // for half
+
+            // Bias correction
+            acc[0] = vsubq_s32(acc[0], bias_lo);
+            acc[1] = vsubq_s32(acc[1], bias_hi);
+
+            // Apply superblock scale (no mins for q6_K)
+            // acc[g] has [c0, c1, c2, c3]
+            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
+            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
+
+            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
+            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo   = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi   = vdupq_n_u8(0x30);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[2];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        acc_f32[0] = vdupq_n_f32(0);
+        acc_f32[1] = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
+
+            int32x2_t acc[col_pairs];
+            for (int i = 0; i < col_pairs; i++) {
+                acc[i] = vdup_n_s32(0);
+            }
+
+            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
+            // Reused for bias and dequantization later
+            int16_t q6_scales[16 * 8];
+            for (int i = 0; i < 16; i++) {
+                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                vst1q_s16(q6_scales + i * 8, scales);
+            }
+
+            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
+            int32x4_t bias_lo = vdupq_n_s32(0);
+            int32x4_t bias_hi = vdupq_n_s32(0);
+
+            // Load bsums in chunks of 4 to process with vectorized operations
+            for (int i = 0; i < 16; i += 4) {
+                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
+                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
+                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
+                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
+                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
+                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
+                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
+                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
+                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
+
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
+            }
+            bias_lo = vshlq_n_s32(bias_lo, 5);
+            bias_hi = vshlq_n_s32(bias_hi, 5);
+
+            // Process two 128-value halves per superblock
+            for (int half = 0; half < 2; half++) {
+                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                // A subblock (sb) is a set of weights that share the scale
+                // Since q6_K scales are per 16 elements
+                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
+                    const int8_t * q8_base_h = q8_base_l + 64;
+
+                    // Load and duplicate q8 values (each register covers two interleaved columns of q6)
+                    int8x16_t q8_l[2];
+                    int8x16_t q8_h[2];
+                    for (int i = 0; i < 2; i++) {
+                        q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
+                        q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
+                    }
+
+                    const int ql_off_base = sb * QK_K / 2;
+                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
+
+                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
+
+                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
+                    if (sb > 1) {
+                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
+                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
+                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
+                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
+                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
+                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
+                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
+                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
+                    }
+
+                    // Process column pairs (0-1, 2-3, 4-5, 6-7)
+                    for (int cp = 0; cp < col_pairs; cp++) {
+                        const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
+                        const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
+                        const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
+                        const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
+
+                        // Extract high 2 bits for upper nibble reconstruction
+                        const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
+                        const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
+
+                        // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
+                        const int8x16_t q6_l0 = vreinterpretq_s8_u8(
+                            vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
+                        const int8x16_t q6_l1 = vreinterpretq_s8_u8(
+                            vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
+                        const int8x16_t q6_h0 =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
+                        const int8x16_t q6_h1 =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
+
+                        int32x4_t sb_acc_l = vdupq_n_s32(0);
+                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
+                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
+
+                        int32x4_t sb_acc_h = vdupq_n_s32(0);
+                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
+                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
+
+                        // Pairwise add to get per-column sums: [col0, col1]
+                        int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
+                        int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
+
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        // Access scales using array indexing (scales are interleaved by column)
+                        const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
+                                                        (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
+                        const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
+                                                        (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
+
+                        // Accumulate scaled results
+                        acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
+                        acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
+                    }
+                }
+            }  // for half
+
+            // Bias correction
+            acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
+            acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
+            acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
+            acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
+
+            // Apply superblock scale (no mins for q6_K)
+            // acc[cp] has [c0, c1]
+            float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
+            float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
+            float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
+            float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
+
+            acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
+            acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x4_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t        acc   = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a  = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t        acc   = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low  = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd     = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t  a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t   a0       = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t   a1       = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t   a2       = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t   a3       = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad       = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const void * b_ptr = vx;
+    const void * a_ptr = vy;
+    float * res_ptr = s;
+    size_t res_stride = bs * sizeof(float);
+
+    __asm__ __volatile__(
+        "mov x10, %x[nr]\n"
+        "mov x9, #0x88\n"
+        "cmp x10, #0x10\n"
+        "mul x9, %x[nb], x9\n"
+        "blt 4f\n"
+        "1:"  // Row loop
+        "add x28, %x[b_ptr], #0x8\n"
+        "mov x27, %x[nc]\n"
+        "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+        "2:"  // Column loop
+        "add x25, %x[a_ptr], #0x8\n"
+        "movi v15.16b, #0x0\n"
+        "movi v19.16b, #0x0\n"
+        "mov x24, %x[nb]\n"
+        "add x23, x25, x9\n"
+        "movi v18.16b, #0x0\n"
+        "movi v14.16b, #0x0\n"
+        "add x22, x23, x9\n"
+        "movi v11.16b, #0x0\n"
+        "movi v13.16b, #0x0\n"
+        "add x21, x22, x9\n"
+        "movi v23.16b, #0x0\n"
+        "movi v16.16b, #0x0\n"
+        "movi v25.16b, #0x0\n"
+        "movi v7.16b, #0x0\n"
+        "movi v0.16b, #0x0\n"
+        "movi v4.16b, #0x0\n"
+        "movi v5.16b, #0x0\n"
+        "movi v21.16b, #0x0\n"
+        "movi v8.16b, #0x0\n"
+        "movi v1.16b, #0x0\n"
+        "3:"  // Block loop
+        "ldr q3, [x28, #0x0]\n"
+        "ldr q31, [x25, #0x0]\n"
+        "movi v28.16b, #0x4\n"
+        "movi v10.4s, #0x0\n"
+        "ldr q22, [x28, #0x10]\n"
+        "ldr q6, [x25, #0x10]\n"
+        "movi v29.4s, #0x0\n"
+        "movi v9.4s, #0x0\n"
+        "ldr q27, [x28, #0x20]\n"
+        "ldr q30, [x28, #0x30]\n"
+        "movi v20.4s, #0x0\n"
+        "movi v24.16b, #0xf0\n"
+        "ldr d2, [x25, #-0x8]\n"
+        "ldr d26, [x23, #-0x8]\n"
+        "sshl v12.16b, v3.16b, v28.16b\n"
+        "sub x20, x28, #0x8\n"
+        "ldr d17, [x20, #0x0]\n"
+        "and v3.16b, v3.16b, v24.16b\n"
+        "subs x24, x24, #0x1\n"
+        "add x28, x28, #0x48\n"
+        ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
+        ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
+        ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
+        ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
+        "sshl v31.16b, v22.16b, v28.16b\n"
+        "and v22.16b, v22.16b, v24.16b\n"
+        "fcvtl v17.4s, v17.4h\n"
+        "fcvtl v2.4s, v2.4h\n"
+        "fcvtl v26.4s, v26.4h\n"
+        ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
+        ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
+        ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
+        ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
+        "sshl v6.16b, v27.16b, v28.16b\n"
+        "sshl v28.16b, v30.16b, v28.16b\n"
+        "and v27.16b, v27.16b, v24.16b\n"
+        "and v30.16b, v30.16b, v24.16b\n"
+        "ldr q24, [x25, #0x20]\n"
+        ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x30]\n"
+        ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x40]\n"
+        ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x50]\n"
+        ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
+        ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x60]\n"
+        ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x25, #0x70]\n"
+        "add x25, x25, #0x88\n"
+        ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
+        ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
+        ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
+        "fmul v24.4s, v17.4s, v2.s[0]\n"
+        "scvtf v10.4s, v10.4s, #0x4\n"
+        "scvtf v29.4s, v29.4s, #0x4\n"
+        "scvtf v9.4s, v9.4s, #0x4\n"
+        "scvtf v20.4s, v20.4s, #0x4\n"
+        "fmla v15.4s, v10.4s, v24.4s\n"
+        "ldr q24, [x23, #0x0]\n"
+        "fmul v10.4s, v17.4s, v2.s[1]\n"
+        "fmla v19.4s, v29.4s, v10.4s\n"
+        "ldr q10, [x23, #0x10]\n"
+        "fmul v29.4s, v17.4s, v2.s[2]\n"
+        "fmul v2.4s, v17.4s, v2.s[3]\n"
+        "fmla v18.4s, v9.4s, v29.4s\n"
+        "movi v9.4s, #0x0\n"
+        "movi v29.4s, #0x0\n"
+        ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
+        "fmla v14.4s, v20.4s, v2.4s\n"
+        "movi v20.4s, #0x0\n"
+        "movi v2.4s, #0x0\n"
+        ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x20]\n"
+        ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
+        ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
+        ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
+        ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x30]\n"
+        ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+        ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x40]\n"
+        ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
+        ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
+        ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
+        ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x50]\n"
+        ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+        ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
+        ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+        "ldr q24, [x23, #0x60]\n"
+        ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
+        ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
+        ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
+        ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
+        "ldr q10, [x23, #0x70]\n"
+        "add x23, x23, #0x88\n"
+        ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
+        ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+        ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
+        ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+        "ldr q24, [x22, #0x0]\n"
+        ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
+        ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
         ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
         ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
         "fmul v10.4s, v17.4s, v26.s[0]\n"
@@ -2120,89 +3157,1372 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         );
         return;
     }
-#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Tiled GEMM: s[nr x nc] = A * B^T, where A (vy) holds q8_0 activations packed
+// 4 rows per block (block_q8_0x4) and B (vx) holds IQ4_NL weights packed
+// 4 columns per block (block_iq4_nlx4). Each (y, x) iteration emits a 4x4
+// float tile of the row-major output s (row stride bs). Falls back to the
+// generic reference implementation when the NEON+DOTPROD path is compiled out.
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;  // number of blocks along the reduction dimension
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    // Silence unused-parameter warnings when the SIMD path below is compiled out.
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    // 16-entry signed codebook; vqtbl1q_s8 maps each 4-bit quant index to its int8 value.
+    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            // sumf[m]: running fp32 output for tile row m (4 columns per vector).
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                // Per-block scales: a_d = 4 activation-row scales, b_d = 4 weight-column scales.
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k = 0; k < 4; k++) {
+                    // Activations: the first 64 bytes of qs pair with the low nibbles,
+                    // the next 64 bytes with the high nibbles of the same weight bytes.
+                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                    // Lane m of the activation vector holds row m's 4 values, so
+                    // sumi_m accumulates row m against all 4 interleaved columns.
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                }
+
+                // Apply the combined scale (b_d * a_d[row]) and accumulate in fp32.
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            // Store the finished 4x4 tile: row y*4+m, columns x*4 .. x*4+3.
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Tiled GEMM: s[nr x nc] = A * B^T, where A (vy) holds q8_0 activations packed
+// 4 rows per block (block_q8_0x4) and B (vx) holds MXFP4 weights packed
+// 4 columns per block (block_mxfp4x4). Identical structure to the IQ4_NL
+// kernel above except for the codebook and the E8M0-exponent column scales.
+// Falls back to the generic implementation when NEON+DOTPROD is unavailable.
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;  // number of blocks along the reduction dimension
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    // Silence unused-parameter warnings when the SIMD path below is compiled out.
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    // 16-entry signed codebook; vqtbl1q_s8 maps each 4-bit quant index to its int8 value.
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+            // sumf[m]: running fp32 output for tile row m (4 columns per vector).
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                // Per-column scales decoded from the shared E8M0 exponents (HALF variant).
+                float32x4_t b_d = {
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+                };
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k = 0; k < 4; k++) {
+                    // First 64 bytes of qs pair with low nibbles, next 64 with high nibbles.
+                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                    // Lane m of the activation vector holds row m's 4 values.
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                }
+
+                // Apply the combined scale (b_d * a_d[row]) and accumulate in fp32.
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            // Store the finished 4x4 tile: row y*4+m, columns x*4 .. x*4+3.
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Tiled GEMM: s[nr x nc] = A * B^T, where A (vy) holds q8_K activations packed
+// 4 rows per block (block_q8_Kx4) and B (vx) holds Q4_K weights packed
+// 8 columns per block (block_q4_Kx8). Each (y, x) pair produces a 4x8 float
+// tile of the row-major output s (row stride bs). Per superblock the kernel
+// accumulates sub-block dot products scaled by the 6-bit scales, and corrects
+// for the Q4_K mins via the q8 bsums at the end. Requires NEON+DOTPROD;
+// otherwise falls back to the generic implementation.
+void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    // Silence unused warnings when the SIMD path below is compiled out.
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    acc_size  = 2 * 4;  // 4 rows × 2 column quartets (c0123 / c4567)
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+
+    // acc_f32[2*row + g]: row's running sums for column quartet g (0 -> c0123, 1 -> c4567)
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // d4 0 1 2 3, 4 5 6 7
+                float32x4_t q4_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
+                float32x4_t q4_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
+                // d8 0 1 2 3
+                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);
+                // mins
+                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
+                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+
+                // Precomputation of scales and mins:
+                // sbd_scale_XXXX[r] = d4(cols) * d8(row r); sbd_min_XXXX[r] = dmin(cols) * d8(row r)
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+                float32x4_t sbd_min_0123[q8_k_blocklen];
+                float32x4_t sbd_min_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
+                sbd_min_0123[0]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
+                sbd_min_4567[0]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
+
+                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
+                sbd_min_0123[1]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
+                sbd_min_4567[1]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
+
+                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
+                sbd_min_0123[2]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
+                sbd_min_4567[2]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
+
+                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
+                sbd_min_0123[3]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
+                sbd_min_4567[3]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
+
+                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
+                const int16x8_t bsums[q8_k_blocklen] = {
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                // NOTE(review): filled per pairwise-summed vector but consumed below as
+                // bsums_arr[sb][2*row(+1)]; this relies on the block_q8_Kx4 bsums
+                // interleaving — confirm against the packing layout.
+                int16_t bsums_arr[QK_K / 64][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
+                int32x4_t bias_acc[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Int accumulators for qs vecdot (4 row x 2 col quartets)
+                    int32x4_t acc_lo[acc_size];
+                    int32x4_t acc_hi[acc_size];
+                    for (int i = 0; i < acc_size; i++) {
+                        acc_lo[i] = vdupq_n_s32(0);
+                        acc_hi[i] = vdupq_n_s32(0);
+                    }
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    // decode_q_Kx8_6bit_scales unpacks the packed 6-bit scales/mins
+                    // for all 8 interleaved columns of this sub-block half.
+                    int16x8_t q4sb_scales[2];
+                    int16x8_t q4sb_mins[2];
+                    for (int i = 0; i < 2; i++) {
+                        int8_t    aux_q4sb[8];
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                    }
+
+                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows
+                    for (int k = 0; k < reads_per_sb; k++) {
+                        // blk0: activations for elements 0..31 of the sub-block,
+                        // blk1 (+128 bytes): elements 32..63, matching the high nibbles.
+                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+                        // 0..3 & 32..35
+                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123
+                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123
+                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123
+                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123
+
+                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
+                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
+                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
+                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
+
+                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567
+                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567
+                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567
+                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567
+
+                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
+                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
+                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
+                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
+                    }
+
+                    // Scale and bias application
+                    // acc is stored interleaved to match output layout
+                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+                    for (int row = 0; row < q8_k_blocklen; row++) {
+                        // Bias correction
+                        // row c0123 blk0 and blk1
+                        const float32x4_t sumf_0123 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+                        // row c4567 blk0 and blk1
+                        const float32x4_t sumf_4567 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+                        // Bias
+                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+                        // row c0123 blk0 and blk1
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                        // row c4567 blk0 and blk1
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Subtract the accumulated min bias: acc -= dmin*d8 * sum(bsums * mins)
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row])
+                    acc_f32[2 * row + 1] =
+                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                }
+            }  // for b
+
+            // Write the 4x8 tile: two float32x4 stores per row (c0123 then c4567).
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+// Tiled GEMM: s[nr x nc] = A * B^T, where A (vy) holds q8_K activations packed
+// 4 rows per block (block_q8_Kx4) and B (vx) holds Q5_K weights packed
+// 8 columns per block (block_q5_Kx8). Same structure as ggml_gemm_q4_K_8x4_q8_K;
+// the only difference is merging the fifth quant bit from the qh plane into
+// each nibble before the dot products. Requires NEON+DOTPROD; otherwise falls
+// back to the generic implementation.
+void ggml_gemm_q5_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    // Silence unused warnings when the SIMD path below is compiled out.
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    acc_size      = 2 * 4;  // 4 rows, 2 column quartets (c0123 / c4567)
+    constexpr int    col_groups    = ncols_interleaved / 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mone          = vdupq_n_u8(1);   // mask for the low-nibble high bit
+    const uint8x16_t mtwo          = vdupq_n_u8(2);   // mask for the high-nibble high bit
+
+    // acc_f32[2*row + g]: row's running sums for column quartet g (0 -> c0123, 1 -> c4567)
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // d5 0 1 2 3, 4 5 6 7
+                float32x4_t q5_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
+                float32x4_t q5_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
+                // d8 0 1 2 3
+                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);
+                // mins
+                float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
+                float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
+
+                // Precomputation of scales and mins:
+                // sbd_scale_XXXX[r] = d5(cols) * d8(row r); sbd_min_XXXX[r] = dmin(cols) * d8(row r)
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+                float32x4_t sbd_min_0123[q8_k_blocklen];
+                float32x4_t sbd_min_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
+                sbd_min_0123[0]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
+                sbd_min_4567[0]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
+
+                sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
+                sbd_min_0123[1]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
+                sbd_min_4567[1]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
+
+                sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
+                sbd_min_0123[2]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
+                sbd_min_4567[2]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
+
+                sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
+                sbd_min_0123[3]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
+                sbd_min_4567[3]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
+
+                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
+                const int16x8_t bsums[q8_k_blocklen] = {
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                // NOTE(review): filled per pairwise-summed vector but consumed below as
+                // bsums_arr[sb][2*row(+1)]; this relies on the block_q8_Kx4 bsums
+                // interleaving — confirm against the packing layout.
+                int16_t bsums_arr[QK_K / 64][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
+                int32x4_t bias_acc[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                // Load the whole qh (fifth-bit) plane once per superblock; it is
+                // shifted right by 2 per sub-block iteration to expose the next bits.
+                uint8x16_t qh[col_groups][8];
+                for (int c = 0; c < col_groups; c++) {
+                    for (int i = 0; i < 8; i++) {
+                        qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+                    }
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Int accumulators for qs vecdot (4 row * 2 col quartets)
+                    int32x4_t acc_lo[acc_size];
+                    int32x4_t acc_hi[acc_size];
+                    for (int i = 0; i < acc_size; i++) {
+                        acc_lo[i] = vdupq_n_s32(0);
+                        acc_hi[i] = vdupq_n_s32(0);
+                    }
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    // decode_q_Kx8_6bit_scales unpacks the packed 6-bit scales/mins
+                    // for all 8 interleaved columns of this sub-block half.
+                    int16x8_t q5sb_scales[2];
+                    int16x8_t q5sb_mins[2];
+                    for (int i = 0; i < 2; i++) {
+                        int8_t    aux_q5sb[8];
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                        q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                    }
+
+                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows
+                    for (int k = 0; k < reads_per_sb; k++) {
+                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+                        // 0..3 & 32..35
+                        const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
+                        const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+                        // NOTE: This is the only difference with q4_K
+                        // bit0 -> fifth bit of the low nibble, bit1 (shifted to bit 4)
+                        // -> fifth bit of the high nibble; qh then advances 2 bits.
+                        const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
+                        const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
+                        qh[0][k]                      = vshrq_n_u8(qh[0][k], 2);
+                        const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
+                        const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
+                        qh[1][k]                      = vshrq_n_u8(qh[1][k], 2);
+                        // From here, same as q4_K
+
+                        // vsliq_n_u8 inserts the high bit at position 4 over the low nibble.
+                        const int8x16_t q5_0123_lo =
+                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
+                        const int8x16_t q5_0123_hi =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
+
+                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123
+                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123
+                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123
+                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123
+
+                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
+                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
+                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
+                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
+
+                        const int8x16_t q5_4567_lo =
+                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
+                        const int8x16_t q5_4567_hi =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
+
+                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567
+                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567
+                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567
+                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567
+
+                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
+                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
+                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
+                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
+                    }
+
+                    // Scale and bias application
+                    // acc is stored interleaved to match output layout
+                    const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+                    const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+                    const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+                    const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+                    for (int row = 0; row < q8_k_blocklen; row++) {
+                        // Bias correction
+                        // row c0123 blk0 and blk1
+                        const float32x4_t sumf_0123 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+                        // row c4567 blk0 and blk1
+                        const float32x4_t sumf_4567 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+                        // Bias
+                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+                        // row c0123 blk0 and blk1
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+
+                        // row c4567 blk0 and blk1
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Subtract the accumulated min bias: acc -= dmin*d8 * sum(bsums * mins)
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+                    acc_f32[2 * row + 1] =
+                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                }
+            }  // for b
+
+            // Write the 4x8 tile: two float32x4 stores per row (c0123 then c4567).
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q4_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {  // s: nr x nc output tile (row stride bs); vx: interleaved q4_K blocks; vy: interleaved q8_K blocks
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;  // super-blocks per row
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (svcntb() * 8 == 256) {  // this SVE path assumes 256-bit vectors
+        constexpr int    q8_k_blocklen = 4;
+        const svuint8_t m4b_1          = svdup_n_u8(0x0f);  // low-nibble mask
+        // 8 accumulators: 2 row pairs × 4 col pairs
+        svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
+        uint32_t idx_arr[8] = { 0, 2, 4, 6,  1, 3, 5, 7 };
+        svbool_t pg = svptrue_pat_b32(SV_VL8);
+        svuint32_t idx = svld1(pg, idx_arr);  // lane permutation for scale interleave
+
+        static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
+        svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);  // lane permutation for mmla-output reorder
+
+        for (int y = 0; y < nr / q8_k_blocklen; y++) {
+            const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+                acc_f32_01 = svdup_n_f32(0);
+                acc_f32_23 = svdup_n_f32(0);
+                acc_f32_45 = svdup_n_f32(0);
+                acc_f32_67 = svdup_n_f32(0);
+
+                for (int b = 0; b < nb; b++) {
+                    // bsums pairs belong to the same q8_k subblock
+                    // 64 elements loaded; pairwise sums of 0-7/8-15 and 16-23/24-31
+                    const int16x8_t bsums[4]{
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                    };
+
+                    int32_t bsums_arr32[4][8];  // widened bsums, one row of 8 per q8 row
+
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        int16x8_t v16 = bsums[q8_row];
+
+                        // low 4
+                        int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
+                        vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
+
+                        // high 4
+                        int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
+                        vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
+                    }
+
+                    svint32_t sb_acc_0 = svdup_n_s32(0);  // per-subblock partial dot products
+                    svint32_t sb_acc_2 = svdup_n_s32(0);
+
+                    svint32_t acc_00 = svdup_n_s32(0);  // integer accumulators per column pair / row pair
+                    svint32_t acc_11 = svdup_n_s32(0);
+                    svint32_t acc_22 = svdup_n_s32(0);
+                    svint32_t acc_33 = svdup_n_s32(0);
+                    svint32_t acc_44 = svdup_n_s32(0);
+                    svint32_t acc_55 = svdup_n_s32(0);
+                    svint32_t acc_66 = svdup_n_s32(0);
+                    svint32_t acc_77 = svdup_n_s32(0);
+
+                    svint32_t bias_acc_00 = svdup_n_s32(0);  // bsum*min bias accumulators, one per q8 row
+                    svint32_t bias_acc_22 = svdup_n_s32(0);
+                    svint32_t bias_acc_44 = svdup_n_s32(0);
+                    svint32_t bias_acc_66 = svdup_n_s32(0);
+
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        // Need scales for the low and high nibbles
+                        // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                        svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
+                        svint32_t q4sb_mins_0, q4sb_mins_1;
+                        {
+                            // two 12-byte packed scale/min groups for this subblock
+                            const int offset = sb * 24 + 0 * 12;
+                            const uint8_t * scales_in = &q4_ptr[b].scales[offset];
+
+                            const int offset1 = sb * 24 + 12;
+                            const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
+
+                            constexpr uint32_t kmask1 = 0x3f3f3f3f;
+                            constexpr uint32_t kmask2 = 0x0f0f0f0f;
+                            constexpr uint32_t kmask3 = 0x03030303;
+                            constexpr uint8_t  scales_size = 12;
+
+                            uint32_t sm[3];
+                            memcpy(sm, scales_in, scales_size);
+
+                            uint32_t sm1[3];
+                            memcpy(sm1, scales_in1, scales_size);
+
+                            const uint32_t mins_0_3 = sm[1] & kmask1;  // unpack 6-bit mins
+                            const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+
+                            const uint32_t mins_0_3_1 = sm1[1] & kmask1;
+                            const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
+
+                            svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
+                            svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
+
+                            /* reinterpret u32 → u8 */
+                            svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
+                            svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
+
+                            /* widen u8 → u16->u32 (lower half only) */
+                            svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
+                            svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
+
+                            q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
+                            q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
+
+                            uint32_t scales_u32_0 = sm[0] & kmask1;  // unpack 6-bit scales
+                            uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+                            uint32_t scales_u32_2 = sm1[0] & kmask1;
+                            uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
+
+                            svuint32_t S01 = svdup_n_u32(scales_u32_0);
+                            svuint32_t S23 = svdup_n_u32(scales_u32_1);
+                            svuint32_t R01 = svdup_n_u32(scales_u32_2);
+                            svuint32_t R23 = svdup_n_u32(scales_u32_3);
+
+                            svint8_t S01_b = svreinterpret_s8_u32(S01);
+                            svint8_t S23_b = svreinterpret_s8_u32(S23);
+                            svint8_t R01_b = svreinterpret_s8_u32(R01);
+                            svint8_t R23_b = svreinterpret_s8_u32(R23);
+
+                            svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));  // widen duplicated scale bytes to s32
+                            svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
+                            svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
+                            svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
+
+                            block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);  // interleave to match mmla output lane order
+                            block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
+                            block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
+                            block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
+                        }
+
+                        const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
+
+                        // Load 32-byte per row pair, 1 subblock each time
+                        // predicate for activating higher lanes for 16 int8 elements
+                        const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                        // predicate for activating lower lanes for  16 int8 elements
+                        const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                        svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
+                        svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
+                        svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
+                        svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
+
+                        svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
+                        svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
+                        svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
+                        svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
+
+                        // Q4s columns iterated in pairs (01, 23, 45, 67)
+                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+
+                            sb_acc_0 = svdup_n_s32(0);
+                            sb_acc_2 = svdup_n_s32(0);
+
+                            svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
+                            svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
+                            svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
+                            svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
+
+                            svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));  // low nibbles in high lanes, high nibbles in low lanes
+                            svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
+                            svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
+                            svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
+
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
+
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
+
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
+
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
+
+                            if(cp == 0) {
+                                acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
+                                acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
+                            }
+                            if(cp == 1) {
+                                acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
+                                acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
+                            }
+                            if(cp == 2) {
+                                acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
+                                acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
+                            }
+                            if(cp == 3) {
+                                acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
+                                acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
+                            }
+                        }
+
+                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);  // accumulate bsum * min per q8 row
+                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
+
+                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
+                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
+
+                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
+                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
+
+                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
+                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
+                    }  // for sb
+
+                    // Fold the upper 4 lanes of each 8-lane accumulator into the lower 4.
+                    acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
+                    acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
+                    acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
+                    acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
+                    acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
+                    acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
+                    acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
+                    acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
+
+                    svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);  // rearrange mmla 2x2 tiles into row-major lanes
+                    svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
+
+                    svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
+                    svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
+
+                    // Broadcast q8 scalar
+                    svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
+
+                    svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
+
+                    svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
+
+                    svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);  // subtract min bias, add scaled dot product
+                    acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[1]);
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
+                    acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[2]);
+
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
+                    acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[3]);
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
+                    acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
+
+                }  // for b
+
+                // With the previous reorder, the tile is already in the correct memory layout.
+                // Predicate for exactly 4 lanes
+                svbool_t pg4 = svptrue_pat_b32(SV_VL4);
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    int row = y * q8_k_blocklen + i;
+                    for (int j = 0; j < 2; j++) {
+                        int col    = x * ncols_interleaved + j * 4;
+                        int offset = row * bs + col;
+
+                        if (i == 0 && j == 0) {
+                            // acc_f32_0 → lower half of acc_f32_01
+                            svst1_f32(pg4, s + offset, acc_f32_01);
+                        } else if (i == 0 && j == 1) {
+                            // acc_f32_1 → upper half of acc_f32_01
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
+                        } else if (i == 1 && j == 0) {
+                            // acc_f32_2
+                            svst1_f32(pg4, s + offset, acc_f32_23);
+                        } else if (i == 1 && j == 1) {
+                            // acc_f32_3
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
+                        } else if (i == 2 && j == 0) {
+                            // acc_f32_4
+                            svst1_f32(pg4, s + offset, acc_f32_45);
+                        } else if (i == 2 && j == 1) {
+                            // acc_f32_5
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
+                        } else if (i == 3 && j == 0) {
+                            // acc_f32_6
+                            svst1_f32(pg4, s + offset, acc_f32_67);
+                        } else if (i == 3 && j == 1) {
+                            // acc_f32_7
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
+                        }
+                    }
+                }
+            }  // for x
+        }  // for y
+        return;
+    }
+#endif  // SVE compile-time end
 
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);  // low-nibble mask
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // bsums pairs belong to the same q8_k subblock
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q4sb_scales[2][8];
+                    int16x8_t q4sb_mins[2];  // int16, as it's needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32-byte per row pair, 1 subblock each time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+                    };
+
+                    // Q4s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+                        const int8x16_t q4_nibbles[2][4] = {  // [0]: low nibbles (first 32 qs), [1]: high nibbles (last 32 qs)
+                            {
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+                            },
+                            {
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+                            }
+                        };
+
+                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
+                        // for each of the internal 32 qs subblock (blk)
+                        for (int rp = 0; rp < 2; rp++) {
+                            for (int blk = 0; blk < 2; blk++) {
+                                const int8x16_t * q8  = &q8s[rp][4 * blk];
+                                const int8x16_t * q4  = q4_nibbles[blk];
+                                int32x4_t         acc = sb_acc[2 * rp + blk];
+                                // mul add for each qs in the same subblock
+                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+                                }
+                                sb_acc[2 * rp + blk] = acc;
+                            }
+                        }
+
+                        // Scales[i] corresponds to column i
+                        const int scale_offset = cp * 2;
+                        const int32_t scale_00 = q4sb_scales[0][scale_offset];
+                        const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
+                        const int32_t scale_10 = q4sb_scales[1][scale_offset];
+                        const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
+                        const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
+                        const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
+
+                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
+                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
+                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
+                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
+                    }
+
+                    // Multiply Acc bsum + mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks share the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder of i8mm output with bias and output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);
+
+                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);  // subtract min bias, add scaled dot product
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
+void ggml_gemm_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
 
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
 
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
     UNUSED(nb);
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    col_pairs     = ncols_interleaved / 2;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mone          = vdupq_n_u8(1);
+    const uint8x16_t mtwo          = vdupq_n_u8(2);
+
+    // 8 accumulators: 2 row pairs × 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
 
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
         for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
 
-            float32x4_t sumf[4];
-            for (int m = 0; m < 4; m++) {
-                sumf[m] = vdupq_n_f32(0);
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
             }
 
-            for (int l = 0; l < nb; l++) {
-                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
-                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+            for (int b = 0; b < nb; b++) {
+                // bsums pairs belong to the same q8_k subblock
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
 
-                int32x4_t sumi_0 = vdupq_n_s32(0);
-                int32x4_t sumi_1 = vdupq_n_s32(0);
-                int32x4_t sumi_2 = vdupq_n_s32(0);
-                int32x4_t sumi_3 = vdupq_n_s32(0);
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
 
-                for (int k = 0; k < 4; k++) {
-                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
-                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+                // Load qh once per block and shift after each subblock
+                const uint8_t * qh_base = q5_ptr[b].qh;
+                uint8x16_t      qh[col_pairs][4];
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                    qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                    qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                    qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+                }
 
-                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
-                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
-                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q5sb_scales[2][8];
+                    int16x8_t q5sb_mins[2];  // int16 as its needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
+                    }
 
-                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
-                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
-                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
-                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
-                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
-                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
-                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
-                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, 1 subblock each time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
+                         q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
+                         q8_qs_23[7] },
+                    };
+
+                    // Q5s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < col_pairs; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+
+                        // This is the only part of the algorithm that differs from Q4_K
+                        // Extract High bits and pack into 5 bit weights
+                        uint8x16_t hbit_lo_0    = vandq_u8(qh[cp][0], mone);
+                        uint8x16_t hbit_hi_0    = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
+                        qh[cp][0]               = vshrq_n_u8(qh[cp][0], 2);
+                        // Same as Q4_K, i8mm to dequantize the weights.
+                        const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
+                        int32x4_t       acc_0   = sb_acc[0];
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
+                        int32x4_t acc_2         = sb_acc[2];
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
+                        const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
+                        int32x4_t       acc_1   = sb_acc[1];
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
+                        int32x4_t acc_3         = sb_acc[3];
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
+
+                        // Repeat for the other 3 columns (8..15, 16..23, 24..31)
+                        uint8x16_t hbit_hi_1    = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
+                        uint8x16_t hbit_lo_1    = vandq_u8(qh[cp][1], mone);
+                        qh[cp][1]               = vshrq_n_u8(qh[cp][1], 2);
+                        const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
+                        const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
+
+                        uint8x16_t hbit_hi_2    = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
+                        uint8x16_t hbit_lo_2    = vandq_u8(qh[cp][2], mone);
+                        qh[cp][2]               = vshrq_n_u8(qh[cp][2], 2);
+                        const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
+                        const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
+
+                        uint8x16_t hbit_lo_3    = vandq_u8(qh[cp][3], mone);
+                        uint8x16_t hbit_hi_3    = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
+                        qh[cp][3]               = vshrq_n_u8(qh[cp][3], 2);
+                        const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
+                        sb_acc[0]               = acc_0;
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
+                        sb_acc[2]               = acc_2;
+
+                        // Scales[i] corresponds to column i
+                        const int       scale_offset = cp * 2;
+                        const int32_t   s0           = q5sb_scales[0][scale_offset];
+                        const int32_t   s1           = q5sb_scales[0][scale_offset + 1];
+                        const int32x4_t block_scale  = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
+
+                        const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
+                        sb_acc[1]               = acc_1;
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
+                        sb_acc[3]               = acc_3;
+
+                        const int32_t   s2           = q5sb_scales[1][scale_offset];
+                        const int32_t   s3           = q5sb_scales[1][scale_offset + 1];
+                        const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
+                    }
+
+                    // Accumulate the bias terms: bsums multiplied by the subblock mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks shares the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdup_n_s16(s)).
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder the i8mm output to match the bias accumulators and the output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
                 }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
 
-                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
-                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
-                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
-                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
-            }
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q5_dmin, q8_d);
 
-            for (int m = 0; m < 4; m++) {
-                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                        float32x4_t       q5_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q5_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
             }
-        }
-    }
+        }  // for x
+    }  // for y
     return;
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+void ggml_gemm_q6_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
 
@@ -2219,171 +4539,167 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     constexpr int    q8_k_blocklen = 4;
-    constexpr int    acc_size  = 2 * 4;  // 2 row pairs × 4 col pairs
-    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    constexpr int    col_groups    = ncols_interleaved / 4;
+    constexpr int    acc_size      = q8_k_blocklen * col_groups;  // 4 rows, 2 column groups
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
+    const int8x16_t  m32s          = vdupq_n_s8(32);
 
-    // 8 accumulators: 2 row pairs × 4 col pairs
     float32x4_t acc_f32[acc_size];
 
     for (int y = 0; y < nr / q8_k_blocklen; y++) {
         const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
 
         for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
 
             for (int i = 0; i < acc_size; i++) {
                 acc_f32[i] = vdupq_n_f32(0);
             }
 
             for (int b = 0; b < nb; b++) {
-                // d4 0 1 2 3, 4 5 6 7
-                float32x4_t q4_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
-                float32x4_t q4_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
-                // d8 0 1 2 3
-                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);
-                // mins
-                float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
-                float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+                float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
+                float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
+                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
 
-                // Precomputation of scales and mins
                 float32x4_t sbd_scale_0123[q8_k_blocklen];
                 float32x4_t sbd_scale_4567[q8_k_blocklen];
-                float32x4_t sbd_min_0123[q8_k_blocklen];
-                float32x4_t sbd_min_4567[q8_k_blocklen];
-
-                sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
-                sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
-                sbd_min_0123[0]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
-                sbd_min_4567[0]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
-
-                sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
-                sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
-                sbd_min_0123[1]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
-                sbd_min_4567[1]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
-
-                sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
-                sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
-                sbd_min_0123[2]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
-                sbd_min_4567[2]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
-
-                sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
-                sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
-                sbd_min_0123[3]   = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
-                sbd_min_4567[3]   = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
 
-                // Precomputation of bsums, each vpaddq calcs all the bsums for each row
-                const int16x8_t bsums[q8_k_blocklen] = {
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
-                };
-                int16_t bsums_arr[QK_K / 64][8];
-                for (int q8_row = 0; q8_row < 4; q8_row++) {
-                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
-                }
+                sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
+                sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
+                sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
+                sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
 
-                // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
-                int32x4_t bias_acc[acc_size];
+                int32x4_t acc_s32[acc_size];
                 for (int i = 0; i < acc_size; i++) {
-                    bias_acc[i] = vdupq_n_s32(0);
+                    acc_s32[i] = vdupq_n_s32(0);
                 }
 
-                for (int sb = 0; sb < QK_K / 64; sb++) {
-                    // Int accumulators for qs vecdot (4 row x 2 col quartets)
-                    int32x4_t acc_lo[acc_size];
-                    int32x4_t acc_hi[acc_size];
-                    for (int i = 0; i < acc_size; i++) {
-                        acc_lo[i] = vdupq_n_s32(0);
-                        acc_hi[i] = vdupq_n_s32(0);
-                    }
-                    // Need scales for the low and high nibbles
-                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
-                    int16x8_t q4sb_scales[2];
-                    int16x8_t q4sb_mins[2];
-                    for (int i = 0; i < 2; i++) {
-                        int8_t    aux_q4sb[8];
-                        const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
-                        q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
-                    }
-
-                    constexpr int reads_per_sb = 8;  // 8 * 16 bytes each => 32 qs * 4 rows
-                    for (int k = 0; k < reads_per_sb; k++) {
-                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
-                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
-
-                        // 0..3 & 32..35
-                        const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
-                        const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
-
-                        const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
-                        const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
-
-                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123
-                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123
-                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123
-                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123
+                int16_t q6_scales[8 * 16];
+                for (int i = 0; i < 16; i++) {
+                    int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                    vst1q_s16(q6_scales + i * 8, scales);
+                }
 
-                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
-                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
-                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
-                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
+                for (int half = 0; half < 2; half++) {
+                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
 
-                        const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
-                        const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        int32x4_t acc_lo[acc_size];
+                        int32x4_t acc_hi[acc_size];
+                        for (int i = 0; i < acc_size; i++) {
+                            acc_lo[i] = vdupq_n_s32(0);
+                            acc_hi[i] = vdupq_n_s32(0);
+                        }
 
-                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567
-                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567
-                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567
-                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567
+                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+                        // 4 rows * 16 elements per scale
+                        // 4 reads of 16 bytes each
+                        constexpr int reads_per_sb = 4;
+                        int8x16_t     q8_l[reads_per_sb];
+                        int8x16_t     q8_h[reads_per_sb];
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
+                            q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
+                        }
 
-                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
-                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
-                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
-                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
-                    }
+                        const int ql_off_base = sb * QK_K / 2;
+                        const int qh_off_base = ql_off_base & 255;
 
-                    // Scale and bias application
-                    // acc is stored interleaved to match output layout
-                    const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
-                    const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
-                    const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
-                    const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
-                    for (int row = 0; row < q8_k_blocklen; row++) {
-                        // Bias correction
-                        // row c0123 blk0 and blk1
-                        const float32x4_t sumf_0123 =
-                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
-                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
-                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+                        uint8x16_t q6_ql_0123[reads_per_sb];
+                        uint8x16_t q6_ql_4567[reads_per_sb];
+                        uint8x16_t q6_qh_0123[reads_per_sb];
+                        uint8x16_t q6_qh_4567[reads_per_sb];
 
-                        // row c4567 blk0 and blk1
-                        const float32x4_t sumf_4567 =
-                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
-                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
-                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
+                            q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
+                            q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
+                            q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
+                        }
 
-                        // Bias
-                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
-                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+                        if (sb > 1) {
+                            for (int k = 0; k < reads_per_sb; k++) {
+                                q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
+                                q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
+                            }
+                        }
 
-                        // row c0123 blk0 and blk1
-                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
-                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            // q = (ql | qh) - 32
+                            const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
+                            const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
+                            const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
+                            const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
+
+                            const int8x16_t q6_0123_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
+                            const int8x16_t q6_0123_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
+
+                            acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0);  //  0..3  r0 c0123
+                            acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1);  //  0..3  r1 c0123
+                            acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2);  //  0..3  r2 c0123
+                            acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3);  //  0..3  r3 c0123
+
+                            acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0);  // 64..67 r0 c0123
+                            acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1);  // 64..67 r1 c0123
+                            acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2);  // 64..67 r2 c0123
+                            acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3);  // 64..67 r3 c0123
+
+                            const int8x16_t q6_4567_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
+                            const int8x16_t q6_4567_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
+
+                            acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0);  //  0..3  r0 c4567
+                            acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1);  //  0..3  r1 c4567
+                            acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2);  //  0..3  r2 c4567
+                            acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3);  //  0..3  r3 c4567
+
+                            acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0);  // 64..67 r0 c4567
+                            acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1);  // 64..67 r1 c4567
+                            acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2);  // 64..67 r2 c4567
+                            acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3);  // 64..67 r3 c4567
+                        }
 
-                        // row c4567 blk0 and blk1
-                        bias_acc[2 * row + 1] =
-                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
-                        bias_acc[2 * row + 1] =
-                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                        // Scale and bias
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        for (int g = 0; g < col_groups; g++) {
+                            const int16x4_t scales_l16  = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
+                            const int16x4_t scales_h16  = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
+                            const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
+                            const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
+                            const int       acc_offset  = g * q8_k_blocklen;
+
+                            for (int row = 0; row < q8_k_blocklen; row++) {
+                                const int idx = row * 2 + g;
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
+                            }
+                        }
                     }
-                }  // for sb
+                }
 
+                // Finally we apply the superblock scales
                 for (int row = 0; row < q8_k_blocklen; row++) {
-                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
-                    acc_f32[2 * row + 1] =
-                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                    const int       idx0     = 2 * row;
+                    const int       idx1     = 2 * row + 1;
+                    const int32x4_t acc_0123 = acc_s32[idx0];
+                    const int32x4_t acc_4567 = acc_s32[idx1];
+
+                    acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
+                    acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
                 }
             }  // for b
 
@@ -2399,10 +4715,10 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }  // for y
     return;
 #endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+    ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_q4_K_8x8_q8_K(int                        n,
+void ggml_gemm_q6_K_8x8_q8_K(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
                              const void * GGML_RESTRICT vx,
@@ -2426,144 +4742,155 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     constexpr int    q8_k_blocklen = 4;
     const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
+    const int8x16_t  m32s          = vdupq_n_s8(32);
 
-    // 8 accumulators: 2 row pairs × 4 col pairs
+    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
     float32x4_t acc_f32[blocklen];
 
     for (int y = 0; y < nr / q8_k_blocklen; y++) {
         const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
 
         for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
 
             for (int i = 0; i < blocklen; i++) {
                 acc_f32[i] = vdupq_n_f32(0);
             }
 
             for (int b = 0; b < nb; b++) {
-                // bsums pairs belongs to the same q8_k subblock
-                const int16x8_t bsums[4]{
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
-                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
-                };
-                int16_t bsums_arr[4][8];
-                for (int q8_row = 0; q8_row < 4; q8_row++) {
-                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
-                }
-
-                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
-                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
-                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                int32x4_t acc[8];  // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
                 for (int i = 0; i < 8; i++) {
-                    acc[i]      = vdupq_n_s32(0);
-                    bias_acc[i] = vdupq_n_s32(0);
+                    acc[i] = vdupq_n_s32(0);
                 }
 
-                for (int sb = 0; sb < QK_K / 64; sb++) {
-                    // Need scales for the low and high nibbles
-                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
-                    int8_t    q4sb_scales[2][8];
-                    int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
-                    for (int i = 0; i < 2; i++) {
-                        const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
-                    }
-
-                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
-                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+                // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
+                // Reused for bias and dequantization later
+                int16_t q6_scales[16 * 8];
+                for (int i = 0; i < 16; ++i) {
+                    int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                    vst1q_s16(q6_scales + i * 8, s16);
+                }
 
-                    int8x16_t q8_qs_01[8];
-                    int8x16_t q8_qs_23[8];
+                // Process two 128-value halves per superblock
+                for (int half = 0; half < 2; half++) {
+
+                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                    // A subblock (sb) is a set of weights that share the scale
+                    // Since q6_K scales are per 16 elements
+                    // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        // Q6_K weight index increasing by 64 instead of 32 requires
+                        // loading various q8 memory regions
+                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+                        int8x16_t q8_l_01[2];
+                        int8x16_t q8_l_23[2];
+                        for (int i = 0; i < 2; i++) {
+                            const int offset = i * 32;
+                            q8_l_01[i]       = vld1q_s8(q8_base_l + offset);       // 0..7 & 8..15 (r01)
+                            q8_l_23[i]       = vld1q_s8(q8_base_l + offset + 16);  // 0..7 & 8..15 (r23)
+                        }
 
-                    // Load 32-byte per row pair, 1 subblock each time
-                    for (int i = 0; i < 8; i++) {
-                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
-                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
-                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
-                    }
+                        int8x16_t q8_h_01[2];
+                        int8x16_t q8_h_23[2];
+                        for (int i = 0; i < 2; i++) {
+                            const int offset = i * 32;
+                            q8_h_01[i]       = vld1q_s8(q8_base_h + offset);
+                            q8_h_23[i]       = vld1q_s8(q8_base_h + offset + 16);
+                        }
 
-                    const int8x16_t q8s[2][8] = {
-                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
-                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
-                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
-                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
-                    };
+                        const int ql_off_base = sb * QK_K / 2;
 
-                    // Q4s columns iterated in pairs (01, 23, 45, 67)
-                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
-                        for (int i = 0; i < 4; i++) {
-                            sb_acc[i] = vdupq_n_s32(0);
+                        uint8x16_t q6_ql_0[4];
+                        uint8x16_t q6_ql_1[4];
+                        for (int k = 0; k < 4; k++) {
+                            q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
+                            q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
                         }
 
-                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
-                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
-                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
-                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
-                        const int8x16_t q4_nibbles[2][4] = {
-                            {
-                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
-                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
-                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
-                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
-                            },
-                            {
-                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
-                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
-                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
-                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
-                            }
-                        };
+                        const int  qh_off_base = (sb * QK_K / 2) & 255;  // wrap after 256 bytes
+                        uint8x16_t q6_qh_0[4];
+                        uint8x16_t q6_qh_1[4];
+                        for (int k = 0; k < 4; k++) {
+                            q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
+                            q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
+                        }
 
-                        // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8
-                        // for each of the internal 32 qs subblock (blk)
-                        for (int rp = 0; rp < 2; rp++) {
-                            for (int blk = 0; blk < 2; blk++) {
-                                const int8x16_t * q8  = &q8s[rp][4 * blk];
-                                const int8x16_t * q4  = q4_nibbles[blk];
-                                int32x4_t         acc = sb_acc[2 * rp + blk];
-                                // mul add for each qs in the same subblock
-                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
-                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
-                                }
-                                sb_acc[2 * rp + blk] = acc;
+                        // Adjust for the proper high bits (Sb 2 and 3)
+                        if (sb > 1) {
+                            for (int k = 0; k < 4; k++) {
+                                q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
+                                q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
                             }
                         }
 
-                        // Scales[i] corresponds to column i
-                        const int scale_offset = cp * 2;
-                        for (int blk = 0; blk < 2; blk++) {
-                            const int32x4_t block_scale = {
-                                (int32_t) q4sb_scales[blk][scale_offset],
-                                (int32_t) q4sb_scales[blk][scale_offset],
-                                (int32_t) q4sb_scales[blk][scale_offset + 1],
-                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                        // Process column pairs (0-1, 2-3, 4-5, 6-7)
+                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                            const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
+                            const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
+                            const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
+                            const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
+
+                            // Extract high 2 bits for upper nibble reconstruction
+                            const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
+                            const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
+
+                            // q6 = (low4 | high2<<4) - 32
+                            // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)
+                            const int8x16_t q6_l0 = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
+                                m32s);
+                            const int8x16_t q6_l1 = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
+                                m32s);
+                            const int8x16_t q6_h0 = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
+                            const int8x16_t q6_h1 = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
+
+                            // row pair 0, base_l
+                            int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
+                            sb_acc_0l           = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
+                            // row pair 0, base_h
+                            int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
+                            sb_acc_0h           = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
+                            // row pair 1, base_l
+                            int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
+                            sb_acc_1l           = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
+                            // row pair 1, base_h
+                            int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
+                            sb_acc_1h           = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
+
+                            const int scale_idx_l = half * 8 + sb;
+                            const int scale_idx_h = half * 8 + sb + 4;
+
+                            const int32x4_t scale_vec_l = {
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+                            };
+                            const int32x4_t scale_vec_h = {
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
                             };
-                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
-                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
-                        }
-                    }
-
-                    // Multiply Acc bsum + mins
-                    for (int q8_row = 0; q8_row < 4; q8_row++) {
-                        // Each pair of subblocks share the same bsums
-                        // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
-                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
-                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
 
-                        bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
-                        bias_acc[2 * q8_row] =
-                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
-                        bias_acc[2 * q8_row + 1] =
-                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
-                        bias_acc[2 * q8_row + 1] =
-                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
+                        }
                     }
-                }  // for sb
+                }  // for half
 
-                // Reorder of i8mm output with bias and output layout
+                // Reorder i8mm output to match memory layout
                 for (int i = 0; i < 8; i++) {
                     int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
                     acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
@@ -2579,23 +4906,20 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
                     vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
                 };
 
+                // Apply superblock scale (no mins for q6_K)
                 for (int i = 0; i < q8_k_blocklen; i++) {
                     for (int j = 0; j < 2; j++) {
-                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
-                        float32x4_t       q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
-                        const float32x4_t dmins   = vmulq_f32(q4_dmin, q8_d);
-
-                        float32x4_t       q4_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
-                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+                        float32x4_t       q8_d  = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q6_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q6_d, q8_d);
 
-                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
                         acc_f32[2 * i + j] =
                             vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
                     }
                 }
             }  // for b
 
-            // With the previous reorder, the tile is already in the correct memory layout.
+            // Store results
             for (int i = 0; i < q8_k_blocklen; i++) {
                 int row = y * q8_k_blocklen + i;
                 for (int j = 0; j < 2; j++) {
@@ -2608,5 +4932,160 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
     }  // for y
     return;
 #endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+    ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_4x4_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k_group = 0; k_group < 8; k_group += 4) {
+                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
+                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
+
+                    for (int k = 0; k < 4; k++) {
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
+                    }
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_4x8_q8_0(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+    for (int y = 0; y < nr; y += 4) {
+        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+        for (int x = 0; x < nc; x += ncols_interleaved) {
+            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+            const block_q8_0x4 * a_ptr = a_ptr_base;
+
+            float32x4_t acc_f32[4];
+            for (int i = 0; i < 4; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                int32x4_t acc[4];
+                for (int i = 0; i < 4; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                }
+
+                // Process 4 chunks of 8 positions each
+                for (int chunk = 0; chunk < 4; chunk++) {
+                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
+                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
+                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
+                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
+                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
+                }
+
+                // Reorder outputs from 2×2 tiles to row-major
+                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+                // Scales
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+                a_ptr++;
+                b_ptr++;
+            }
+
+            for (int row = 0; row < 4; row++) {
+                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
+            }
+        }
+    }
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp
new file mode 100644
index 00000000000..fedd6430278
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp
@@ -0,0 +1,82 @@
+# include "ggml-backend-impl.h"
+
+#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
+
+#if defined(__linux__)
+#include <sys/auxv.h>
+#endif
+
+#include <string>
+
+struct powerpc_features {
+    std::string platform = "";
+    int power_version    = -1;
+
+    bool has_vsx         = false;
+
+    powerpc_features() {
+#if defined(__linux__)
+        unsigned long auxval = getauxval(AT_PLATFORM);
+        if (auxval) {
+            platform = std::string(reinterpret_cast<const char *>(auxval));
+            // TBD: Do systems exist that return this in uppercase?
+            if (platform.substr(0, 5) == "power") {
+                // Extract a numeric suffix, if one exists
+                int vpos = -1;
+                for (int i = platform.length() - 1; i >= 0; i--) {
+                    if (std::isdigit(platform[i])) {
+                        vpos = i;
+                    } else {
+                        break;
+                    }
+                }
+                if (vpos > -1) {
+                    power_version = std::stoi(platform.substr(vpos));
+                }
+            }
+        }
+#endif
+        if (power_version >= 9) {
+            has_vsx = true;
+        }
+    }
+};
+
+static int ggml_backend_cpu_powerpc_score() {
+    int score = 1;
+    powerpc_features pf;
+
+// Platform scores
+#if defined(GGML_USE_POWER7)
+    if (pf.power_version < 7) { return 0; }
+    score += 1<<1;
+#endif
+#if defined(GGML_USE_POWER8)
+    if (pf.power_version < 8) { return 0; }
+    score += 1<<2;
+#endif
+#if defined(GGML_USE_POWER9)
+    if (pf.power_version < 9) { return 0; }
+    score += 1<<3;
+#endif
+#if defined(GGML_USE_POWER10)
+    if (pf.power_version < 10) { return 0; }
+    score += 1<<4;
+#endif
+#if defined(GGML_USE_POWER11)
+    if (pf.power_version < 11) { return 0; }
+    score += 1<<5;
+#endif
+
+// Feature scores
+#if defined(GGML_USE_VSX)
+    if (!pf.has_vsx) { return 0; }
+    score += 1<<6;
+#endif
+
+    return score;
+}
+
+GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score)
+
+#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c
new file mode 100644
index 00000000000..d3dfd049eaf
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c
@@ -0,0 +1,2305 @@
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "simd-mappings.h"
+
+#include "../../quants.h"
+#include "../../ggml-cpu-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+#define UNUSED GGML_UNUSED
+
+#if defined(__POWER9_VECTOR__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * GGML_RESTRICT y = vy;
+
+#if defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_CPU_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_ref(x, y, k);
+#endif
+}
+
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * GGML_RESTRICT y = vy;
+
+#if defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_CPU_FP32_TO_FP16(d);
+
+        vector int accv = vec_splats(0);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+
+            accv = vec_add(accv, vi[j]);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+
+        accv = vec_add(accv, vec_sld(accv, accv, 4));
+        accv = vec_add(accv, vec_sld(accv, accv, 8));
+        y[i].s = GGML_CPU_FP32_TO_FP16(d * vec_extract(accv, 0));
+    }
+
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_1_ref(x, y, k);
+#endif
+}
+
+
+//===================================== Dot products =================================
+
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector signed char v8 = vec_splats((signed char)0x8);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_sub(q4x0, v8);
+        q4x1 = vec_sub(q4x1, v8);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_1 * GGML_RESTRICT x = vx;
+    const block_q8_1 * GGML_RESTRICT y = vy;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q4x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q4x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char vshift4 = vec_splats((unsigned char)4);
+    vector float vsumf0 = vec_splats(0.0f);
+
+    vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *
+                                      GGML_E8M0_TO_FP32_HALF(x[ib].e));
+
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);
+
+        vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);
+
+        vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);
+        vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+    sumf = vec_extract(vsumf0, 0);
+    *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char v4 = vec_splats((unsigned char)4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};
+        vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);
+        vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
+
+        qv0 = vec_add(qv0, qv1);
+
+        vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_1 * GGML_RESTRICT x = vx;
+    const block_q8_1 * GGML_RESTRICT y = vy;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};
+        vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);
+        vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q5x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q5x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q8_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char q8x0 = vec_xl( 0, x[ib].qs);
+        vector signed char q8x1 = vec_xl(16, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed short qv0 = vec_mule(q8x0, q8y0);
+        vector signed short qv1 = vec_mulo(q8x0, q8y0);
+        vector signed short qv2 = vec_mule(q8x1, q8y1);
+        vector signed short qv3 = vec_mulo(q8x1, q8y1);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+        vsumi0 = vec_sum4s(qv2, vsumi0);
+        vsumi1 = vec_sum4s(qv3, vsumi1);
+
+        vsumi0 = vec_add(vsumi0, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q2_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowScaleMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales);
+        vector signed char vscales = vec_and(q2xmins, lowScaleMask);
+
+        q2xmins = vec_sr(q2xmins, v4);
+        vector signed short q2xmins0 = vec_unpackh(q2xmins);
+        vector signed short q2xmins1 = vec_unpackl(q2xmins);
+
+        vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);
+        vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);
+        vector signed int prod2 = vec_mule(q2xmins1, q8ysums1);
+        vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q2);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q2);
+            q2 += 32;
+
+            vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);
+            vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);
+            vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);
+            vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);
+            vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);
+            vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed int qv0 = vec_msum(q8y00, q2x00, v0);
+            vector signed int qv1 = vec_msum(q8y01, q2x01, v0);
+            vector signed int qv2 = vec_msum(q8y02, q2x02, v0);
+            vector signed int qv3 = vec_msum(q8y03, q2x03, v0);
+            vector signed int qv4 = vec_msum(q8y10, q2x10, v0);
+            vector signed int qv5 = vec_msum(q8y11, q2x11, v0);
+            vector signed int qv6 = vec_msum(q8y12, q2x12, v0);
+            vector signed int qv7 = vec_msum(q8y13, q2x13, v0);
+
+            vector signed short vscales_07 = vec_unpackh(vscales);
+            vector signed int vscales_03 = vec_unpackh(vscales_07);
+            vector signed int vscales_47 = vec_unpackl(vscales_07);
+            vector signed int vs0 = vec_splat(vscales_03, 0);
+            vector signed int vs1 = vec_splat(vscales_03, 1);
+            vector signed int vs2 = vec_splat(vscales_03, 2);
+            vector signed int vs3 = vec_splat(vscales_03, 3);
+            vector signed int vs4 = vec_splat(vscales_47, 0);
+            vector signed int vs5 = vec_splat(vscales_47, 1);
+            vector signed int vs6 = vec_splat(vscales_47, 2);
+            vector signed int vs7 = vec_splat(vscales_47, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);
+            vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);
+            vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);
+            vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);
+            vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowMask1 = vec_splats((int8_t)0xf);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector signed char v1 = vec_splats((signed char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(u0, lowMask1);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));
+        vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4);
+        vector signed char u31 = vec_and(u3, lowMask2);
+
+        u1 = vec_or(u1, u30);
+        u2 = vec_or(vec_sr(u0, v4), u31);
+
+        vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
+
+        vscales = vec_sub(vscales, off);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
+            q3 += 32;
+
+            //the low 2 bits
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
+            vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
+            vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
+            vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
+            vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
+
+            //the 3rd bit
+            vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
+            vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
+            vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
+            vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
+            vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
+            vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
+            vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
+            vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
+            qxhs0 = vec_sr(qxhs0, v4);
+            qxhs1 = vec_sr(qxhs1, v4);
+
+            vector signed char q3x00 = vec_sub(qxs00, qxh00);
+            vector signed char q3x01 = vec_sub(qxs01, qxh01);
+            vector signed char q3x02 = vec_sub(qxs02, qxh02);
+            vector signed char q3x03 = vec_sub(qxs03, qxh03);
+            vector signed char q3x10 = vec_sub(qxs10, qxh10);
+            vector signed char q3x11 = vec_sub(qxs11, qxh11);
+            vector signed char q3x12 = vec_sub(qxs12, qxh12);
+            vector signed char q3x13 = vec_sub(qxs13, qxh13);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed short vscales_h = vec_unpackh(vscales);
+            vector signed short vs0 = vec_splat(vscales_h, 0);
+            vector signed short vs1 = vec_splat(vscales_h, 1);
+            vector signed short vs2 = vec_splat(vscales_h, 2);
+            vector signed short vs3 = vec_splat(vscales_h, 3);
+            vector signed short vs4 = vec_splat(vscales_h, 4);
+            vector signed short vs5 = vec_splat(vscales_h, 5);
+            vector signed short vs6 = vec_splat(vscales_h, 6);
+            vector signed short vs7 = vec_splat(vscales_h, 7);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
+            vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
+            vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
+            vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
+            vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
+            vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
+            vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
+            vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
+
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs2, vsumi1);
+            vsumi2 = vec_msum(qv02, vs4, vsumi2);
+            vsumi3 = vec_msum(qv03, vs6, vsumi3);
+            vsumi4 = vec_msum(qv10, vs1, vsumi4);
+            vsumi5 = vec_msum(qv11, vs3, vsumi5);
+            vsumi6 = vec_msum(qv12, vs5, vsumi6);
+            vsumi7 = vec_msum(qv13, vs7, vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row of Q4_K-quantized weights with Q8_K-quantized
+// activations.  Per super-block the result accumulates
+//   d*dy * sum(scale_j * q4·q8)  -  dmin*dy * sum(min_j * bsum_j),
+// as shown by the vec_madd(…, vd, …) and vec_nmsub(…, vdmin, …) lines below.
+// n must be a multiple of QK_K and only nrc == 1 is supported.  The POWER9
+// VSX path is used when available; otherwise the generic kernel is called.
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    // kmask*/utmp belong to the scalar unpacking scheme; on the VSX path the
+    // scales are unpacked with vector ops instead, so these are only
+    // referenced through UNUSED() to silence warnings.
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((uint8_t)2);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    // Four partial float accumulators, reduced to a single lane at the end.
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        // Per-super-block combined scales: d*dy and dmin*dy.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        // Per-16-element sums of the q8 quants, precomputed by the quantizer.
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
+        // Unpack the 12 packed bytes of 6-bit scales/mins (8 + 4 bytes loaded
+        // with vec_xl_len) so that after the merge below the low half of
+        // utmps holds the scales and the high half the mins.
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        // Sign-extend scales (low half) and mins (high half) to 16 bits.
+        vector signed short vscales = vec_unpackh(utmps);
+        vector signed short q4xmins = vec_unpackl(utmps);
+        vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
+        vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
+
+        // mins · bsums, widened to int32; subtracted from the result below
+        // via vec_nmsub (the -dmin*dy correction term).
+        vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        // Each iteration consumes 64 packed q4 bytes (= 128 quants) against
+        // 128 q8 bytes, i.e. two 64-quant sub-blocks at once (j += 2).
+        for (int j = 0; j < QK_K/64; j+=2) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
+            q4 += 64;
+
+            // Split every byte into its low and high 4-bit quants.
+            vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);
+            vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);
+            vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);
+            vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);
+            vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);
+            vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y20 = vec_xl( 64, q8);
+            vector signed char q8y30 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            // vec_msum: widening multiply of 16 byte pairs with 4-way
+            // accumulation into each of the four int32 lanes.
+            vector signed int qv00 = vec_msum(q8y00, q4x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q4x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q4x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q4x11, v0);
+            vector signed int qv20 = vec_msum(q8y20, q4x20, v0);
+            vector signed int qv21 = vec_msum(q8y21, q4x21, v0);
+            vector signed int qv30 = vec_msum(q8y30, q4x30, v0);
+            vector signed int qv31 = vec_msum(q8y31, q4x31, v0);
+
+            // Broadcast the four sub-block scales for this pass, then rotate
+            // vscales by 8 bytes so the next iteration sees the next four.
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vector signed int vs2 = vec_splat(vscales_h, 2);
+            vector signed int vs3 = vec_splat(vscales_h, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            // The two 16-byte halves of a 64-quant sub-block share one scale,
+            // hence qv00/qv10 both use vs0, qv01/qv11 both use vs1, etc.
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv30, vs3), vsumi3);
+
+            vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);
+        }
+
+        // Scale the integer partial sums by d*dy and fold into the float
+        // accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction: fold the four accumulators together, then sum
+    // the four float lanes into lane 0 via two rotate-and-add steps.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row of Q5_K-quantized weights with Q8_K-quantized
+// activations.  Structurally identical to the Q4_K kernel above, except each
+// 5-bit quant is rebuilt from a 4-bit nibble in x[i].qs plus one high bit
+// taken from x[i].qh (see the q5h* OR-merge in the inner loop).
+// n must be a multiple of QK_K and only nrc == 1 is supported; without
+// __POWER9_VECTOR__ the generic kernel is called instead.
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy,  size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    // Scalar-path scale-unpacking constants; unused on the VSX path (the
+    // scales are unpacked with vector ops below).
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v1 = vec_splats((unsigned char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    // Four partial float accumulators, reduced to a single lane at the end.
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        // Per-super-block combined scales: d*dy and dmin*dy.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
+        // Unpack the 12 packed bytes of 6-bit scales/mins; after the merge,
+        // the low half of utmps holds the scales and the high half the mins.
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        // Per-16-element sums of the q8 quants, precomputed by the quantizer.
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed short vscales = vec_unpackh(utmps);
+
+        vector signed short q5xmins = vec_unpackl(utmps);
+        vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
+        vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+
+        // mins · bsums; subtracted from the result via vec_nmsub
+        // (the -dmin*dy correction term).
+        vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        // High (5th) bits for the whole super-block; consumed two bits per
+        // lane per iteration via the vec_sr(…, v2) shifts in the loop.
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        // Each iteration consumes 32 packed q5 bytes (= 64 quants) against
+        // 64 q8 bytes.
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q5, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
+            q5 += 32;
+
+            // Low 4 bits of each quant.
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+
+            // Extract the 5th bit for each half and position it at bit 4.
+            vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
+            vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
+            vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
+            vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
+            qxhs0 = vec_sr(qxhs0, v2);
+            qxhs1 = vec_sr(qxhs1, v2);
+
+            // Recombine into full unsigned 5-bit quants.
+            vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);
+            vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);
+            vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);
+            vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);
+
+            vector signed char q8y00 = vec_xl( 0, q8);
+            vector signed char q8y10 = vec_xl(16, q8);
+            vector signed char q8y01 = vec_xl(32, q8);
+            vector signed char q8y11 = vec_xl(48, q8);
+            q8 += 64;
+
+            // vec_msum: widening multiply with 4-way accumulation into int32.
+            vector signed int qv00 = vec_msum(q8y00, q5x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q5x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q5x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q5x11, v0);
+
+            // Two scales per iteration (one per 32-quant half); rotate
+            // vscales by 12 bytes to expose the next pair.
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vscales = vec_sld(vscales, vscales, 12);
+
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);
+        }
+
+        // Scale the integer partial sums by d*dy and fold into the floats.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators into lane 0.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row of Q6_K-quantized weights with Q8_K-quantized
+// activations.  Each 6-bit quant is rebuilt from a 4-bit nibble in x[i].ql
+// plus two bits from x[i].qh, then recentred by subtracting off = 0x20
+// (so values span [-32, 31]); 8-bit per-sub-block scales come from
+// x[i].scales.  Unlike Q4_K/Q5_K there is no dmin/bsums correction term.
+// n must be a multiple of QK_K and only nrc == 1 is supported; without
+// __POWER9_VECTOR__ the generic kernel is called instead.
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q6_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    // Recentre offset: packed quants are unsigned 6-bit, stored value - 32.
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    // Four partial float accumulators, reduced to a single lane at the end.
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined per-super-block scale d*dy.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        // Eight int32 accumulators; pairs are folded together after the loop.
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t  * GGML_RESTRICT qs = x[i].scales;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        // Each iteration consumes 64 low-nibble bytes + 32 high-bit bytes
+        // (= 128 quants) against 128 q8 bytes, with 8 scales.
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q6, 0, 0);
+            __builtin_prefetch(qh, 0, 0);
+            __builtin_prefetch(q8, 0, 0);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
+            q6 += 64;
+
+            // Low 4 bits of each quant.
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+            vector signed char qxs20 = vec_and(qxs2, lowMask);
+            vector signed char qxs21 = vec_sr(qxs2, v4);
+            vector signed char qxs30 = vec_and(qxs3, lowMask);
+            vector signed char qxs31 = vec_sr(qxs3, v4);
+
+            vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
+            vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
+            qh += 32;
+
+            // Extract each quant's two high bits (mask 0x3 at shifts
+            // 0/2/4/6) and position them at bits 4-5.
+            vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
+            vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
+            vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
+            vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
+            vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
+            vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
+            vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
+            vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
+
+            // Recombine to 6-bit values and recentre to signed [-32, 31].
+            vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
+            vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
+            vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
+            vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
+            vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
+            vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
+            vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
+            vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y20 = vec_xl( 32, q8);
+            vector signed char q8y30 = vec_xl( 48, q8);
+            vector signed char q8y01 = vec_xl( 64, q8);
+            vector signed char q8y11 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            // Widening byte products: even and odd lanes summed into int16.
+            vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
+            vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
+            vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
+            vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
+            vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
+            vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
+            vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
+            vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
+
+            // Load and sign-extend the next 8 int8 sub-block scales.
+            vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
+            qs += 8;
+
+            vector signed short vs0 = vec_splat(vscales, 0);
+            vector signed short vs1 = vec_splat(vscales, 1);
+            vector signed short vs2 = vec_splat(vscales, 2);
+            vector signed short vs3 = vec_splat(vscales, 3);
+            vector signed short vs4 = vec_splat(vscales, 4);
+            vector signed short vs5 = vec_splat(vscales, 5);
+            vector signed short vs6 = vec_splat(vscales, 6);
+            vector signed short vs7 = vec_splat(vscales, 7);
+
+            // vec_msum folds scale * int16 products pairwise into int32.
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs4, vsumi1);
+            vsumi2 = vec_msum(qv10, vs1, vsumi2);
+            vsumi3 = vec_msum(qv11, vs5, vsumi3);
+            vsumi4 = vec_msum(qv20, vs2, vsumi4);
+            vsumi5 = vec_msum(qv21, vs6, vsumi5);
+            vsumi6 = vec_msum(qv30, vs3, vsumi6);
+            vsumi7 = vec_msum(qv31, vs7, vsumi7);
+        }
+
+        // Fold the eight integer accumulators down to four.
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        // Scale by d*dy and fold into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators into lane 0.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+#if defined (__POWER9_VECTOR__)
+// Sign lookup table for the iq2 kernels below: 128 rows of 8 signs (+1/-1)
+// stored as int8, addressed as 64-bit rows through the `signs64` alias and
+// indexed by 7-bit fields of the quant data (the `& 127` lookups in
+// ggml_vec_dot_iq2_xxs_q8_K).  NOTE(review): the name suggests each row's
+// 8th sign is chosen so the count of -1s is even ("k even signs") — the
+// construction is not derivable from this file alone; confirm upstream.
+static const int8_t keven_signs_q2xs[1024] = {
+     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
+     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
+     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
+     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
+     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
+     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
+     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
+     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
+     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
+     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
+     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
+     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
+     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
+     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
+     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
+     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
+     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
+     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
+     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
+     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
+     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
+     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
+     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
+     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
+     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
+     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
+     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
+     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
+     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
+     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
+     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
+     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+#endif
+
+// Dot product of one row quantized as iq2_xxs (vx) against one row quantized
+// as q8_K (vy); the scalar result is written to *s.  n is the element count
+// and must be a multiple of QK_K.  Only nrc == 1 (a single row pair) is
+// supported, so the stride arguments bs/bx/by are unused.
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    // VSX path: four independent integer accumulators are kept per
+    // super-block to expose instruction-level parallelism; they are folded
+    // into the float accumulators once per super-block.
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    // keven_signs_q2xs, viewed as 64-bit words, is a table of precomputed
+    // 8-byte +/-1 sign patterns indexed by the 7-bit fields below.
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // 16 bytes of packed codes: the bytes of aux32[0]/aux32[2] index
+            // 8-byte rows of iq2xxs_grid, while aux32[1]/aux32[3] pack four
+            // 7-bit sign indices each plus a 4-bit scale in the top nibble.
+            uint32_t aux32[4];
+            const uint8_t * aux8 = (const uint8_t *)aux32;
+
+            memcpy(aux32, q2, 4*sizeof(uint32_t));
+            q2 += 8;
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])};
+
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))};
+
+            // Apply the signs with a byte-wise multiply by the +/-1 pattern.
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            // Group scales live in the top 4 bits; effective weight is 2*ls+1.
+            const uint16_t ls0 = aux32[1] >> 28;
+            const uint16_t ls1 = aux32[3] >> 28;
+
+            vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    // NOTE(review): the 0.125f factor presumably undoes a x8 scaling baked
+    // into the iq2_xxs format — confirm against the generic implementation.
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row quantized as iq2_xs (vx) against one row quantized
+// as q8_K (vy); result goes to *s.  n must be a multiple of QK_K and only
+// nrc == 1 is supported (bs/bx/by unused).  Unlike iq2_xxs, each 16-bit code
+// carries a 9-bit grid index (low bits) and a 7-bit sign index (high bits),
+// and per-group scales are stored as nibbles in x[i].scales.
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    // 64-bit view of the precomputed +/-1 sign-pattern table.
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // Low 9 bits of each code select an 8-byte row of iq2xs_grid.
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))};
+
+            // High 7 bits select the sign pattern for those 8 bytes.
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))};
+            q2 += 8;
+
+            // Apply the signs with a byte-wise multiply by the +/-1 pattern.
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            // Four 4-bit group scales per 64 values; effective weight 2*ls+1.
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    // NOTE(review): 0.125f presumably undoes a x8 scaling baked into the
+    // iq2_xs format — confirm against the generic implementation.
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row quantized as iq2_s (vx) against one row quantized
+// as q8_K (vy); result goes to *s.  n must be a multiple of QK_K and only
+// nrc == 1 is supported (bs/bx/by unused).  iq2_s stores 8-bit grid indices
+// in qs with 2 extra high bits per value in qh, explicit per-value sign bits
+// after the codes, and 4-bit group scales in x[i].scales.
+void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    // k_mask1 replicates each of 4 consecutive sign bytes across 8 byte
+    // lanes; k_mask2 then isolates one bit per lane so vec_cmpeq turns the
+    // packed sign bits into full 0x00/0xFF byte masks.
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t *  GGML_RESTRICT q2 = x[i].qs;
+        const uint8_t *  GGML_RESTRICT qh = x[i].qh;
+        // Per-value sign bits are stored right after the QK_K/8 code bytes.
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const uint8_t *  GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // Grid index = 8 bits from q2 plus 2 high bits shifted in from qh.
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))};
+            q2 += 8;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            // Expand packed sign bits to one 0x00/0xFF mask byte per lane.
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            // Conditional negation: (v ^ m) - m == v when m == 0, -v when m == -1.
+            vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0);
+            vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1);
+            vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2);
+            vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            // Four 4-bit group scales; effective weight is 2*ls+1.
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    // NOTE(review): 0.125f presumably undoes a x8 scaling baked into the
+    // iq2_s format — confirm against the generic implementation.
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row quantized as iq3_xxs (vx) against one row quantized
+// as q8_K (vy); result goes to *s.  n must be a multiple of QK_K and only
+// nrc == 1 is supported (bs/bx/by unused).  iq3_xxs uses 32-bit (4-byte)
+// grid entries addressed by one byte each, with sign indices and a 4-bit
+// group scale packed into trailing 32-bit words.
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    // 64-bit view of the precomputed +/-1 sign-pattern table.
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+        // Sign/scale words are stored after the QK_K/4 grid-index bytes.
+        const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4);
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+// NOTE(review): unrolling is explicitly disabled here, presumably to limit
+// register pressure in this gather-heavy loop — confirm before changing.
+#pragma GCC unroll 1
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // Gather 16 grid entries (4 bytes each) by byte index.
+            vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+            vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+            vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+            vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+            q3 += 16;
+
+            // 7-bit fields of the sign words select 8-byte +/-1 patterns.
+            vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >>  0) & 127]), (uint64_t)(signs64[(signs[0] >>  7) & 127])};
+            vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])};
+            vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >>  0) & 127]), (uint64_t)(signs64[(signs[1] >>  7) & 127])};
+            vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])};
+
+            // Apply the signs with a byte-wise multiply by the +/-1 pattern.
+            vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0);
+            vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1);
+            vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2);
+            vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            // Group scale in the top 4 bits; effective weight is 2*ls+1.
+            const uint16_t ls0 = (uint16_t)(signs[0] >> 28);
+            const uint16_t ls1 = (uint16_t)(signs[1] >> 28);
+            signs += 2;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    // NOTE(review): 0.25f presumably undoes a x4 scaling baked into the
+    // iq3_xxs format — confirm against the generic implementation.
+    *s = 0.25f * vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row quantized as iq3_s (vx) against one row quantized
+// as q8_K (vy); result goes to *s.  n must be a multiple of QK_K and only
+// nrc == 1 is supported (bs/bx/by unused).  iq3_s stores 8-bit grid indices
+// in qs with one extra high bit per value in qh, explicit sign bits in
+// x[i].signs, and 4-bit group scales in x[i].scales.
+void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    // k_mask1 replicates each of 4 consecutive sign bytes across 8 byte
+    // lanes; k_mask2 then isolates one bit per lane so vec_cmpeq turns the
+    // packed sign bits into full 0x00/0xFF byte masks.
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        const uint8_t *  GGML_RESTRICT q3 = x[i].qs;
+        const uint8_t *  GGML_RESTRICT qh = x[i].qh;
+        const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs);
+        const uint8_t *  GGML_RESTRICT sc = x[i].scales;
+        const int8_t  *  GGML_RESTRICT q8 = y[i].qs;
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // Grid index = 8 bits from q3 plus 1 high bit shifted in from qh.
+            vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)],
+                                             iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]};
+            vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)],
+                                             iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]};
+            vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)],
+                                             iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]};
+            vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)],
+                                             iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]};
+            q3 += 16;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            // Expand packed sign bits to one 0x00/0xFF mask byte per lane.
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            // Conditional negation: (v ^ m) - m == v when m == 0, -v when m == -1.
+            vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0);
+            vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1);
+            vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2);
+            vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            // Two 4-bit group scales per byte; effective weight is 2*ls+1.
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            sc ++;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    // No extra constant factor here, unlike the iq2/iq3_xxs kernels.
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+// Dot product of one row quantized as iq1_s (vx) against one row quantized
+// as q8_K (vy); result goes to *s.  n must be a multiple of QK_K and only
+// nrc == 1 is supported (bs/bx/by unused).  Besides the grid lookup, iq1_s
+// adds a signed constant offset (IQ1S_DELTA) per group, which this kernel
+// accounts for via y's per-group block sums (bsums) in vsumi8.
+void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;    // number of QK_K-sized super-blocks
+
+#if defined(__POWER9_VECTOR__)
+    const vector unsigned char v0 = vec_splats((unsigned char)0x0);
+    // Sign-bit mask used to flip the sign bit of 16-bit bsums below.
+    const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        // Combined dequantization scale of the two blocks.
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vector signed int vsumi1 = vec_splats((int32_t)0);
+        vector signed int vsumi2 = vec_splats((int32_t)0);
+        vector signed int vsumi3 = vec_splats((int32_t)0);
+        // Separate accumulator for the IQ1S_DELTA correction term.
+        vector signed int vsumi8 = vec_splats((int32_t)0);
+
+        const uint8_t  * GGML_RESTRICT q1 = x[i].qs;
+        const uint16_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
+        const int16_t  * GGML_RESTRICT qs = y[i].bsums;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q1, 0, 1);
+            __builtin_prefetch(qh, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            // Grid index = 8 bits from q1 plus 3 high bits taken from qh.
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))};
+            q1 += 8;
+
+            // No separate sign application: iq1_s grid entries are used as-is.
+            vector signed char q1x0 = (vector signed char)aux64x2_0;
+            vector signed char q1x1 = (vector signed char)aux64x2_1;
+            vector signed char q1x2 = (vector signed char)aux64x2_2;
+            vector signed char q1x3 = (vector signed char)aux64x2_3;
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            // Even/odd byte products summed into 16-bit lanes.
+            vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3));
+
+            // 3-bit group scales live in bits 12..14 of qh; weight is 2*ls+1.
+            const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7);
+            const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7);
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+            vector signed short vscales = vec_sld(vscales23, vscales01, 8);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+
+            // Delta term: load four 16-bit block sums of y for these groups.
+            vector signed short q8ysums = vec_xl_len(qs, 8);
+            qs += 4;
+            q8ysums = vec_mergeh(q8ysums, (vector signed short)v0);
+
+            // NOTE(review): the top bit of each qh word encodes the delta
+            // sign; vec_sel picks the sign-bit-flipped bsums when it is set.
+            // Confirm the exact semantics against the scalar implementation.
+            vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);
+            qh += 2;
+            vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);
+
+            vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);
+
+            vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);
+        }
+
+        // Fold this super-block's integer sums into the float accumulators.
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+
+        // Add the per-group delta correction, scaled by IQ1S_DELTA.
+        vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0);
+    }
+
+    // Horizontal reduction of the four accumulators down to one lane.
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    // No POWER9 vectors available: defer to the scalar reference kernel.
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * GGML_RESTRICT x = vx;
+    const block_q8_0   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK4_NL;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_perm(values, values, (vector unsigned char)q4x0);
+        q4x1 = vec_perm(values, values, (vector unsigned char)q4x1);
+
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+    *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_K == 0);
+
+    const block_iq4_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ibl].d));
+        vector float vyd = vec_splats(y[ibl].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        uint16_t h = x[ibl].scales_h;
+
+        const uint8_t * GGML_RESTRICT q4 = x[ibl].qs;
+        const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l;
+        const int8_t  * GGML_RESTRICT q8 = y[ibl].qs;
+
+        for (int ib = 0; ib < QK_K/64; ib ++ ) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            q4 += 32;
+
+            vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask);
+            vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4);
+            vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask);
+            vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4);
+
+            q4x00 = vec_perm(values, values, (vector unsigned char)q4x00);
+            q4x01 = vec_perm(values, values, (vector unsigned char)q4x01);
+            q4x10 = vec_perm(values, values, (vector unsigned char)q4x10);
+            q4x11 = vec_perm(values, values, (vector unsigned char)q4x11);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32);
+            const uint16_t ls1 = (uint16_t)(((sc[0] >>  4) | ((h << 2) & 0x30)) - 32);
+            h >>= 4;
+            sc ++;
+
+            vector signed short vscales01 = vec_splats((int16_t)ls0);
+            vector signed short vscales23 = vec_splats((int16_t)ls1);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/quants.c
index cb49320a67f..74d699f633d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -268,9 +268,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
 
-static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
-    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
-                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
 #endif
 #elif defined(__SSSE3__)
@@ -782,6 +782,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
+
     for (; ib + 1 < nb; ib += 2) {
         const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
         const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
@@ -795,10 +796,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
-        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
-                _mm256_cvtepi32_ps(p_1), accum1);
-        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
-                _mm256_cvtepi32_ps(p_2), accum2);
+        const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));
+        const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));
+        accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);
     }
 
     sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
@@ -830,7 +831,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 #endif
     for (; ib < nb; ++ib) {
-        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);
         int sumi1 = 0;
         int sumi2 = 0;
         for (int j = 0; j < QK_MXFP4/2; ++j) {
@@ -3817,4 +3818,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
-
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp
index 7dda9eea0c5..bd6906c4159 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/x86/repack.cpp
@@ -522,7 +522,8 @@ template
 static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v ||
-            std::is_same_v,
+            std::is_same_v ||
+            std::is_same_v,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                } else if constexpr (std::is_same_v) {
+                    // Load 8 E8M0 exponents and convert to float via LUT
+                    // Rearranged to match changemask order: 0,4,1,5,2,6,3,7
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load and convert to FP32 scale from block_q8_0
@@ -628,7 +641,8 @@ template
 static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v ||
-            std::is_same_v,
+            std::is_same_v ||
+            std::is_same_v,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Process LHS in pairs of rows
@@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Process LHS in groups of four
@@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__)
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+    gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+    return;
+#endif
+
+    ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+        gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
+    }
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+    ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp
index 14f5b43ae0e..75e38290015 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/binary-ops.cpp
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
     GGML_ASSERT(nb00 == sizeof(src0_t));
 
     const auto [ir0, ir1] = get_thread_range(params, src0);
-    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
-
-    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
-        GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    }
+    const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
 
 #ifdef GGML_USE_ACCELERATE
     vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
         const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
         const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
-        if (is_src1_contiguous) {
+        if (is_src1_contiguous_rows) {
             // src1 is broadcastable across src0 and dst in i1, i2, i3
             const int64_t nr0 = ne00 / ne10;
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/common.h b/ml/backend/ggml/ggml/src/ggml-cpu/common.h
index 6adca5437f8..abbadc359c5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/common.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/common.h
@@ -6,6 +6,9 @@
 #include "ggml-impl.h"
 #include "simd-mappings.h"
 
+#define GGML_FA_TILE_Q  64
+#define GGML_FA_TILE_KV 64
+
 #ifdef __cplusplus
 
 #include 
@@ -84,4 +87,9 @@ static std::pair get_thread_range(const struct ggml_compute_pa
     return {ir0, ir1};
 }
 
+struct ggml_fa_tile_config {
+    static constexpr size_t Q  = GGML_FA_TILE_Q;
+    static constexpr size_t KV = GGML_FA_TILE_KV;
+};
+
 #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 7597377cc27..88a9c9ec057 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -24,6 +24,9 @@ struct ggml_compute_params {
     void * wdata;
 
     struct ggml_threadpool * threadpool;
+
+    // use reference implementation
+    bool use_ref;
 };
 
 
@@ -328,7 +331,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include 
-#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
 #include 
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
index 8d485131228..b4938ea0b4d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
@@ -5,7 +5,6 @@
 #include "ggml-backend.h"
 #include "traits.h"
 #include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "quants.h"
 #include "ggml-threading.h"
@@ -14,6 +13,7 @@
 #include "vec.h"
 #include "ops.h"
 #include "ggml.h"
+#include "common.h"
 
 #include "ollama-debug.h"
 
@@ -77,6 +77,9 @@
 // precomputed f32 table for f16 (256 KB) (simd-mappings.h)
 float ggml_table_f32_f16[1 << 16];
 
+// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
+float ggml_table_f32_e8m0_half[1 << 8];
+
 #if defined(__ARM_ARCH)
 struct ggml_arm_arch_features_type {
     int sve_cnt;
@@ -2868,10 +2871,20 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne10 = node->src[1]->ne[0]; // DK
-                        const int64_t ne20 = node->src[2]->ne[0]; // DV
+                        const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
+                        const int64_t DK = node->src[1]->ne[0];
+                        const int64_t DV = node->src[2]->ne[0];
+
+                        // Tiled flash attention scratch (tile sizes defined in common.h)
+                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
+                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
 
-                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
+                        // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
+                        // Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
+                        size_t n_chunks = n_tasks;
+                        size_t decode   = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
+
+                        cur += MAX(prefill, decode);
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
@@ -2928,14 +2941,19 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     set_numa_thread_affinity(state->ith);
 
     struct ggml_compute_params params = {
-        /*.ith       =*/ state->ith,
-        /*.nth       =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
-        /*.wsize     =*/ cplan->work_size,
-        /*.wdata     =*/ cplan->work_data,
-        /*.threadpool=*/ tp,
+        /*.ith        =*/ state->ith,
+        /*.nth        =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
+        /*.wsize      =*/ cplan->work_size,
+        /*.wdata      =*/ cplan->work_data,
+        /*.threadpool =*/ tp,
+        /*.use_ref    =*/ cplan->use_ref,
     };
 
-    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
+#ifdef GGML_USE_OPENMP
+    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
+#else
+    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
+#endif
 
     for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2945,6 +2963,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         ggml_compute_forward(¶ms, node);
 
 #ifdef OLLAMA_DEBUG
@@ -2962,7 +2984,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
     }
 
-    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
+#ifdef GGML_USE_OPENMP
+    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
+#else
+    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
+#endif
 
     ggml_barrier(state->threadpool);
 
@@ -3326,13 +3352,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
         __m128 y_vec = _mm_cvtph_ps(x_vec);
         _mm_storeu_ps(y + i, y_vec);
     }
-#elif defined(__riscv_zvfh)
-    for (int vl; i < n; i += vl) {
-        vl = __riscv_vsetvl_e16m1(n - i);
-        vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
-        vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
-        __riscv_vse32_v_f32m2(&y[i], vy, vl);
+
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin)
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m2();
+    const int step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (; i < np; i += step) {
+        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr);
+        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr);
+        __riscv_vse32_v_f32m4(y + i, ay0, epr);
+
+        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr);
+        vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr);
+        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
     }
+
+    // leftovers
+    int vl;
+    for (i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl);
+        vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl);
+        __riscv_vse32_v_f32m4(y + i, ay0, vl);
+    }
+
 #endif
 
     for (; i < n; ++i) {
@@ -3377,6 +3423,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
                                         (const __m128i *)(x + i))),
                                 16)));
     }
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin)
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m2();
+    const int step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (; i < np; i += step) {
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr);
+        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr);
+        __riscv_vse32_v_f32m4(y + i, ay0, epr);
+
+        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr);
+        vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr);
+        __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
+    }
+
+    // leftovers
+    int vl;
+    for (i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl);
+        vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl);
+        __riscv_vse32_v_f32m4(y + i, ay0, vl);
+    }
 #endif
     for (; i < n; i++) {
         y[i] = GGML_BF16_TO_FP32(x[i]);
@@ -3627,6 +3698,11 @@ void ggml_cpu_init(void) {
                 ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             }
 
+            // initialize E8M0 half table (256 entries)
+            for (int i = 0; i < (1 << 8); ++i) {
+                ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
+            }
+
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
             GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
index 92ba577a543..622cf5d24bd 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
 
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
+
+    bool                use_ref;  // use reference implementation
 };
 
 static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
 
     cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
     cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cpu_plan->cplan.use_ref             = cpu_ctx->use_ref;
 
     return cpu_plan;
 }
@@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
 
     cplan.abort_callback      = cpu_ctx->abort_callback;
     cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cplan.use_ref             = cpu_ctx->use_ref;
 
     return ggml_graph_compute(cgraph, &cplan);
 
@@ -225,6 +229,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->work_size           = 0;
     ctx->abort_callback      = NULL;
     ctx->abort_callback_data = NULL;
+    ctx->use_ref             = false;
 
     ggml_backend_t cpu_backend = new ggml_backend {
         /* .guid    = */ ggml_backend_cpu_guid(),
@@ -272,6 +277,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->use_ref = use_ref;
+}
+
 // CPU backend - device
 
 struct ggml_backend_cpu_device_context {
@@ -648,6 +660,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
     if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
         return (void *)ggml_is_numa;
     }
+    if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
+        return (void *)ggml_backend_cpu_set_use_ref;
+    }
 
     // threadpool - TODO:  move to ggml-base
     if (strcmp(name, "ggml_threadpool_new") == 0) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
deleted file mode 100644
index a7078687288..00000000000
--- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
+++ /dev/null
@@ -1,333 +0,0 @@
-#pragma once
-
-typedef vector unsigned char vec_t;
-typedef __vector_quad acc_t;
-
-template 
-class tinyBLAS_Q0_PPC {
-  public:
-    tinyBLAS_Q0_PPC(int64_t k,
-                    const TA *A, int64_t lda,
-                    const block_q8_0 *B, int64_t ldb,
-                    float *C, int64_t ldc,
-                    int ith, int nth);
-
-    void matmul(int64_t m, int64_t n);
-    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
-        vec_t A_pack[mc*kc*2];
-        vec_t B_pack[nc*kc*2];
-        int comparray[mc*kc];
-        constexpr bool is_Ablock_q4 = std::is_same_v;
-        int64_t ytiles = m / mc;
-        int64_t xtiles = n / nc;
-        int64_t tiles  = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles) {
-            end = tiles;
-        }
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = (job / xtiles) * mc;
-            int64_t jj = (job % xtiles) * nc;
-            for (int64_t kk = 0; kk < k; kk += kc) {
-                if constexpr(is_Ablock_q4) {
-                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
-                } else {
-                    packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
-                }
-                packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
-                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
-            }
-        }
-    }
-
-  private:
-    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
-                *c_ptr += *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    template
-    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
-        vector signed int vec_C[4];
-        vector float CA[4] = {0};
-        vector float res[4] = {0};
-        __builtin_mma_disassemble_acc(vec_C, ACC);
-        for (int i = 0; i < 4; i++) {
-            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
-            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
-            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
-        }
-    }
-
-    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
-        const vector signed char lowMask = vec_splats((signed char)0xF);
-        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
-        const vector signed char v8 = vec_splats((signed char)0x8);
-        vector signed int vsum = {0};
-        vector signed int vsum2 = {0};
-        c[0] = vec_and(c[1], lowMask);
-        c[1] = vec_sr(c[1], v4);
-        c[0] = vec_sub(c[0], v8);
-        c[1] = vec_sub(c[1], v8);
-        vsum = vec_sum4s(c[0], vsum);
-        vsum2 = vec_sum4s(c[1], vsum2);
-        vsum = vec_add(vsum, vsum2);
-        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template 
-    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
-        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
-        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
-        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
-        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-        V2 t1, t2, t3, t4, t5, t6, t7, t8;
-        vector unsigned char xor_vector;
-        uint8_t flip_vec = 0x80;
-        xor_vector = vec_splats(flip_vec);
-        t1 = vec_perm(s1, s2, swiz1);
-        t2 = vec_perm(s1, s2, swiz2);
-        t3 = vec_perm(s3, s4, swiz1);
-        t4 = vec_perm(s3, s4, swiz2);
-        t5 = vec_perm(t1, t3, swiz3);
-        t6 = vec_perm(t1, t3, swiz4);
-        t7 = vec_perm(t2, t4, swiz3);
-        t8 = vec_perm(t2, t4, swiz4);
-        if (flip == true) {
-            t5 = vec_xor(t5, xor_vector);
-            t6 = vec_xor(t6, xor_vector);
-            t7 = vec_xor(t7, xor_vector);
-            t8 = vec_xor(t8, xor_vector);
-        }
-        vec_xst(t5, 0, vecOffset);
-        vec_xst(t6, 0, vecOffset+16);
-        vec_xst(t7, 0, vecOffset+32);
-        vec_xst(t8, 0, vecOffset+48);
-    }
-
-    template
-    inline void kernel(int64_t ii, int64_t jj) {
-        if constexpr(RM == 4 && RN == 8) {
-            KERNEL_4x8(ii,jj);
-        } else if constexpr(RM == 8 && RN == 4) {
-            KERNEL_8x4(ii,jj);
-        } else if constexpr(RM == 8 && RN == 8) {
-            KERNEL_8x8(ii,jj);
-        } else {
-            assert(false && "RN/RM values not supported");
-        }
-    }
-    template
-    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray);
-    template
-    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
-    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
-    void KERNEL_4x8(int64_t ii, int64_t jj);
-    void KERNEL_8x4(int64_t ii, int64_t jj);
-    void KERNEL_8x8(int64_t ii, int64_t jj);
-    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
-    template 
-    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
-
-    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
-        for (int I = 0; I<8; I++) {
-            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
-            for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
-                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
-             }
-         }
-    }
-
-    inline void process_q8_elements(const int8_t *qs, int *ca) {
-        vector signed char c1 = vec_xl(0, qs);
-        vector signed char c2 = vec_xl(16, qs);
-        vector signed int vsum1 = {0};
-        vector signed int vsum2 = {0};
-        vsum1 = vec_sum4s(c1, vsum1);
-        vsum2 = vec_sum4s(c2, vsum2);
-        vector signed int vsum = vec_add(vsum1, vsum2);
-        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template
-    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
-        int64_t i, j;
-        block_q8_0 *aoffset = NULL;
-        VA *vecOffset = NULL;
-        block_q8_0* aoffsets[8];
-        __vector_pair arr[8];
-        VB c[8][2] = {0};
-        VB c1[8] = {0}; VB c2[8] = {0};
-        aoffset = const_cast(a);
-        vecOffset = vec;
-        j = (rows >> 3);
-        int index = 0;
-        if (j > 0) {
-            do {
-                for (int it = 0; it < 8; it++)
-                    aoffsets[it] = aoffset + it*lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
-                        c1[it] = c[it][0];
-                        c2[it] = c[it][1];
-                        if (comparray){
-                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
-                        }
-                    }
-                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while(j > 0);
-        }
-
-    }
-
-    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
-        int64_t i, j;
-        TA *aoffset = NULL;
-        int8_t *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        aoffset = const_cast(a);
-        vecOffset = vec;
-        int index = 0;
-        j = (rows >> 3);
-        if (j > 0) {
-            do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    c1[1] = reinterpret_cast(vec_xl(0, (aoffset1+blk)->qs));
-                    c2[1] = reinterpret_cast(vec_xl(0, (aoffset2+blk)->qs));
-                    c3[1] = reinterpret_cast(vec_xl(0, (aoffset3+blk)->qs));
-                    c4[1] = reinterpret_cast(vec_xl(0, (aoffset4+blk)->qs));
-                    c5[1] = reinterpret_cast(vec_xl(0, (aoffset5+blk)->qs));
-                    c6[1] = reinterpret_cast(vec_xl(0, (aoffset6+blk)->qs));
-                    c7[1] = reinterpret_cast(vec_xl(0, (aoffset7+blk)->qs));
-                    c8[1] = reinterpret_cast(vec_xl(0, (aoffset8+blk)->qs));
-
-                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
-                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
-                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
-                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
-                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
-                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
-                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
-                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
-                    vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-                    vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-                    vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while (j > 0);
-        }
-    }
-
-    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
-        acc_t acc[8];
-        for (int i = 0; i < mc ; i += 8) {
-            for (int j = 0; j < nc; j += 8) {
-                vector float fin_res[16] = {0};
-                vector float vs[16] = {0};
-                for (int64_t kk = 0; kk < kc; kk+=2) {
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xxsetaccz(&acc[x]);
-                    }
-                    int A_block_idx = (i/8)*(16*kc) + kk*16;
-                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
-                    vec_t *A_block = &vec_A[A_block_idx];
-                    vec_t *B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x],     B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x],     B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8],   B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk, vs);
-                    int c_index = (i/8)*(8*kc)+ kk*8;
-                    int* c_block = &comparray[c_index];
-                    compute(&acc[0], 0,  0,  c_block, vs, fin_res);
-                    compute(&acc[1], 4,  4,  c_block, vs, fin_res);
-                    compute(&acc[2], 0,  8,  c_block, vs, fin_res);
-                    compute(&acc[3], 4, 12,  c_block, vs, fin_res);
-
-                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
-                    B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
-                    A_block = &vec_A[A_block_idx];
-                    B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x],     B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x],     B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8],   B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk+1, vs);
-                    c_index = (i/8)*(8*kc)+ (kk+1)*8;
-                    c_block = &comparray[c_index];
-                    compute(&acc[4], 0,  0,  c_block, vs, fin_res);
-                    compute(&acc[5], 4,  4,  c_block, vs, fin_res);
-                    compute(&acc[6], 0,  8,  c_block, vs, fin_res);
-                    compute(&acc[7], 4, 12,  c_block, vs, fin_res);
-
-                }
-                if (l == 0) {
-                    save_res(ii+i,   jj+j,    0,  fin_res);
-                    save_res(ii+i+4, jj+j,    4,  fin_res);
-                    save_res(ii+i,   jj+j+4,  8,  fin_res);
-                    save_res(ii+i+4, jj+j+4, 12,  fin_res);
-                } else {
-                    add_save_res(ii+i,   jj+j,    0,  fin_res);
-                    add_save_res(ii+i+4, jj+j,    4,  fin_res);
-                    add_save_res(ii+i,   jj+j+4,  8,  fin_res);
-                    add_save_res(ii+i+4, jj+j+4, 12,  fin_res);
-                }
-            }
-        }
-    }
-
-    const TA *const A;
-    const block_q8_0 *const B;
-    float *C;
-    const int64_t k;
-    int64_t kc;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index a0cce10aa7c..da412fd009b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -69,6 +69,10 @@
 #define VECTOR_REGISTERS 16
 #endif
 
+#if defined(__riscv_v_intrinsic)
+#define LMUL 4
+#endif
+
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
 namespace {
@@ -117,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #endif
 
 #if defined(__MMA__)
-#include "sgemm-ppc.h"
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
@@ -175,6 +180,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
 }
 #endif
 
+#if defined(__riscv_zvfh)
+template <>
+inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
+    return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
+    return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
+    return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
+    return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+}
+inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
+    return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
+    return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
+    return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
+    return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_zvfbfwma)
+inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
+    return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
+}
+inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
+    return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
+}
+inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
+    return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM
 
@@ -227,6 +272,25 @@ inline float hsum(__m512 x) {
 }
 #endif // __AVX512F__
 
+#if defined(__riscv_zvfh)
+inline float hsum(vfloat32m1_t x) {
+    return __riscv_vfmv_f_s_f32m1_f32(
+        __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
+}
+inline float hsum(vfloat32m2_t x) {
+    return __riscv_vfmv_f_s_f32m1_f32(
+        __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2()));
+}
+inline float hsum(vfloat32m4_t x) {
+    return __riscv_vfmv_f_s_f32m1_f32(
+        __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4()));
+}
+inline float hsum(vfloat32m8_t x) {
+    return __riscv_vfmv_f_s_f32m1_f32(
+        __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8()));
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED MEMORY LOADING
 
@@ -315,6 +379,88 @@ template <> inline __m256bh load(const float *p) {
 }
 #endif
 
+#if defined(__riscv_zvfh)
+template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
+    return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
+    return __riscv_vle16_v_f16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1());
+}
+template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
+    return __riscv_vle16_v_f16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2());
+}
+template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
+    return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4());
+}
+template <> inline vfloat32m1_t load(const float *p) {
+    return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
+}
+template <> inline vfloat32m2_t load(const float *p) {
+    return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2());
+}
+template <> inline vfloat32m4_t load(const float *p) {
+    return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4());
+}
+template <> inline vfloat32m8_t load(const float *p) {
+    return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_zvfbfwma)
+template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
+    return __riscv_vle16_v_bf16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
+    return __riscv_vle16_v_bf16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1());
+}
+template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
+    return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2());
+}
+#endif
+
+#if defined(__riscv_zvfh)
+template  T set_zero();
+
+template <> inline vfloat16mf2_t set_zero() {
+    return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
+}
+template <> inline vfloat16m1_t set_zero() {
+    return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
+}
+template <> inline vfloat16m2_t set_zero() {
+    return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
+}
+template <> inline vfloat16m4_t set_zero() {
+    return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
+}
+template <> inline vfloat32m1_t set_zero() {
+    return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
+}
+template <> inline vfloat32m2_t set_zero() {
+    return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2());
+}
+template <> inline vfloat32m4_t set_zero() {
+    return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4());
+}
+template <> inline vfloat32m8_t set_zero() {
+    return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8());
+}
+#endif
+
+#if defined(__riscv_v_intrinsic)
+template  size_t vlmax() {
+    if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e16mf2(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e16m1(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e16m2(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e16m4(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e32m1(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e32m2(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e32m4(); }
+    else if constexpr (std::is_same_v) { return  __riscv_vsetvlmax_e32m8(); }
+    return 0;
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -488,6 +634,573 @@ class tinyBLAS {
     const int64_t ldc;
 };
 
+#if defined(__riscv_v_intrinsic)
+template 
+class tinyBLAS_RVV {
+  public:
+    tinyBLAS_RVV(const ggml_compute_params * params, int64_t k,
+             const TA *A, int64_t lda,
+             const TB *B, int64_t ldb,
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
+    }
+
+    bool matmul(int64_t m, int64_t n) {
+        if (k % vlmax() != 0) {
+            return false;
+        }
+
+#if LMUL == 1
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<6>(n);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
+            return true;
+        }
+#elif LMUL == 2
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
+            return true;
+        }
+#else // LMUL = 4
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+            mnpack<2, 2, 8>(m, n, SIZE_N, 36);
+            return true;
+        }
+        if (m % 8 == 0 ) {
+            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+            mnpack<2, 2, 4>(m, n, SIZE_N, 36);
+            return true;
+        }
+        if (m % 4 == 0) {
+            const int64_t SIZE_N = BLOCK_SIZE<2>(n);
+            mnpack<2, 2, 2>(m, n, SIZE_N, 36);
+            return true;
+        }
+#endif
+        return false;
+    }
+
+  private:
+    template
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
+        if (SIZE_N == RN) {
+            return gemm(m, n, BN);
+        }
+        if constexpr (RN > 1) {
+            return mnpack(m, n, SIZE_N, BN);
+        } else {
+            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            GGML_ASSERT(false); // we have miss something.
+        }
+    }
+
+    inline void gemm_bloc_4x6(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+        D Cv12 = set_zero();
+        D Cv13 = set_zero();
+        D Cv20 = set_zero();
+        D Cv21 = set_zero();
+        D Cv22 = set_zero();
+        D Cv23 = set_zero();
+        D Cv30 = set_zero();
+        D Cv31 = set_zero();
+        D Cv32 = set_zero();
+        D Cv33 = set_zero();
+        D Cv40 = set_zero();
+        D Cv41 = set_zero();
+        D Cv42 = set_zero();
+        D Cv43 = set_zero();
+        D Cv50 = set_zero();
+        D Cv51 = set_zero();
+        D Cv52 = set_zero();
+        D Cv53 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            V Bv2 = load(B + ldb * (jj + 2) + l);
+            V Bv3 = load(B + ldb * (jj + 3) + l);
+            V Bv4 = load(B + ldb * (jj + 4) + l);
+            V Bv5 = load(B + ldb * (jj + 5) + l);
+
+            V Av0 = load(A + lda * (ii + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv20 = madd(Av0, Bv2, Cv20);
+            Cv30 = madd(Av0, Bv3, Cv30);
+            Cv40 = madd(Av0, Bv4, Cv40);
+            Cv50 = madd(Av0, Bv5, Cv50);
+
+            V Av1 = load(A + lda * (ii + 1) + l);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv11 = madd(Av1, Bv1, Cv11);
+            Cv21 = madd(Av1, Bv2, Cv21);
+            Cv31 = madd(Av1, Bv3, Cv31);
+            Cv41 = madd(Av1, Bv4, Cv41);
+            Cv51 = madd(Av1, Bv5, Cv51);
+
+            V Av2 = load(A + lda * (ii + 2) + l);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv12 = madd(Av2, Bv1, Cv12);
+            Cv22 = madd(Av2, Bv2, Cv22);
+            Cv32 = madd(Av2, Bv3, Cv32);
+            Cv42 = madd(Av2, Bv4, Cv42);
+            Cv52 = madd(Av2, Bv5, Cv52);
+
+            V Av3 = load(A + lda * (ii + 3) + l);
+            Cv03 = madd(Av3, Bv0, Cv03);
+            Cv13 = madd(Av3, Bv1, Cv13);
+            Cv23 = madd(Av3, Bv2, Cv23);
+            Cv33 = madd(Av3, Bv3, Cv33);
+            Cv43 = madd(Av3, Bv4, Cv43);
+            Cv53 = madd(Av3, Bv5, Cv53);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+        C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50);
+        C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51);
+        C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52);
+        C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53);
+    }
+
+    inline void gemm_bloc_4x5(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+        D Cv12 = set_zero();
+        D Cv13 = set_zero();
+        D Cv20 = set_zero();
+        D Cv21 = set_zero();
+        D Cv22 = set_zero();
+        D Cv23 = set_zero();
+        D Cv30 = set_zero();
+        D Cv31 = set_zero();
+        D Cv32 = set_zero();
+        D Cv33 = set_zero();
+        D Cv40 = set_zero();
+        D Cv41 = set_zero();
+        D Cv42 = set_zero();
+        D Cv43 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            V Bv2 = load(B + ldb * (jj + 2) + l);
+            V Bv3 = load(B + ldb * (jj + 3) + l);
+            V Bv4 = load(B + ldb * (jj + 4) + l);
+
+            V Av0 = load(A + lda * (ii + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv20 = madd(Av0, Bv2, Cv20);
+            Cv30 = madd(Av0, Bv3, Cv30);
+            Cv40 = madd(Av0, Bv4, Cv40);
+
+            V Av1 = load(A + lda * (ii + 1) + l);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv11 = madd(Av1, Bv1, Cv11);
+            Cv21 = madd(Av1, Bv2, Cv21);
+            Cv31 = madd(Av1, Bv3, Cv31);
+            Cv41 = madd(Av1, Bv4, Cv41);
+
+            V Av2 = load(A + lda * (ii + 2) + l);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv12 = madd(Av2, Bv1, Cv12);
+            Cv22 = madd(Av2, Bv2, Cv22);
+            Cv32 = madd(Av2, Bv3, Cv32);
+            Cv42 = madd(Av2, Bv4, Cv42);
+
+            V Av3 = load(A + lda * (ii + 3) + l);
+            Cv03 = madd(Av3, Bv0, Cv03);
+            Cv13 = madd(Av3, Bv1, Cv13);
+            Cv23 = madd(Av3, Bv2, Cv23);
+            Cv33 = madd(Av3, Bv3, Cv33);
+            Cv43 = madd(Av3, Bv4, Cv43);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+        C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40);
+        C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41);
+        C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42);
+        C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43);
+    }
+
+    inline void gemm_bloc_4x4(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+        D Cv12 = set_zero();
+        D Cv13 = set_zero();
+        D Cv20 = set_zero();
+        D Cv21 = set_zero();
+        D Cv22 = set_zero();
+        D Cv23 = set_zero();
+        D Cv30 = set_zero();
+        D Cv31 = set_zero();
+        D Cv32 = set_zero();
+        D Cv33 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+            V Av2 = load(A + lda * (ii + 2) + l);
+            V Av3 = load(A + lda * (ii + 3) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv03 = madd(Av3, Bv0, Cv03);
+
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv11 = madd(Av1, Bv1, Cv11);
+            Cv12 = madd(Av2, Bv1, Cv12);
+            Cv13 = madd(Av3, Bv1, Cv13);
+
+            V Bv2 = load(B + ldb * (jj + 2) + l);
+            Cv20 = madd(Av0, Bv2, Cv20);
+            Cv21 = madd(Av1, Bv2, Cv21);
+            Cv22 = madd(Av2, Bv2, Cv22);
+            Cv23 = madd(Av3, Bv2, Cv23);
+
+            V Bv3 = load(B + ldb * (jj + 3) + l);
+            Cv30 = madd(Av0, Bv3, Cv30);
+            Cv31 = madd(Av1, Bv3, Cv31);
+            Cv32 = madd(Av2, Bv3, Cv32);
+            Cv33 = madd(Av3, Bv3, Cv33);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+        C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30);
+        C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31);
+        C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32);
+        C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33);
+    }
+
+    inline void gemm_bloc_4x3(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+        D Cv12 = set_zero();
+        D Cv13 = set_zero();
+        D Cv20 = set_zero();
+        D Cv21 = set_zero();
+        D Cv22 = set_zero();
+        D Cv23 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+            V Av2 = load(A + lda * (ii + 2) + l);
+            V Av3 = load(A + lda * (ii + 3) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv03 = madd(Av3, Bv0, Cv03);
+
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv11 = madd(Av1, Bv1, Cv11);
+            Cv12 = madd(Av2, Bv1, Cv12);
+            Cv13 = madd(Av3, Bv1, Cv13);
+
+            V Bv2 = load(B + ldb * (jj + 2) + l);
+            Cv20 = madd(Av0, Bv2, Cv20);
+            Cv21 = madd(Av1, Bv2, Cv21);
+            Cv22 = madd(Av2, Bv2, Cv22);
+            Cv23 = madd(Av3, Bv2, Cv23);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+        C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20);
+        C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21);
+        C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22);
+        C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23);
+    }
+
+    inline void gemm_bloc_4x2(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+        D Cv12 = set_zero();
+        D Cv13 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+            V Av2 = load(A + lda * (ii + 2) + l);
+            V Av3 = load(A + lda * (ii + 3) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv03 = madd(Av3, Bv0, Cv03);
+
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv11 = madd(Av1, Bv1, Cv11);
+            Cv12 = madd(Av2, Bv1, Cv12);
+            Cv13 = madd(Av3, Bv1, Cv13);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+        C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12);
+        C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13);
+    }
+
+    inline void gemm_bloc_4x1(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv02 = set_zero();
+        D Cv03 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+            V Av2 = load(A + lda * (ii + 2) + l);
+            V Av3 = load(A + lda * (ii + 3) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+            Cv02 = madd(Av2, Bv0, Cv02);
+            Cv03 = madd(Av3, Bv0, Cv03);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02);
+        C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03);
+    }
+
+    inline void gemm_bloc_2x2(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+        D Cv10 = set_zero();
+        D Cv11 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+
+            V Bv1 = load(B + ldb * (jj + 1) + l);
+            Cv10 = madd(Av0, Bv1, Cv10);
+            Cv11 = madd(Av1, Bv1, Cv11);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+        C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10);
+        C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11);
+    }
+
+    inline void gemm_bloc_2x1(int64_t ii, int64_t jj) {
+        size_t vl = vlmax();
+        D Cv00 = set_zero();
+        D Cv01 = set_zero();
+
+        for (int64_t l = 0; l < k; l += vl) {
+            V Av0 = load(A + lda * (ii + 0) + l);
+            V Av1 = load(A + lda * (ii + 1) + l);
+
+            V Bv0 = load(B + ldb * (jj + 0) + l);
+            Cv00 = madd(Av0, Bv0, Cv00);
+            Cv01 = madd(Av1, Bv0, Cv01);
+        }
+
+        C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00);
+        C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01);
+    }
+
+    template <int RM, int RN>
+    inline void gemm_bloc(int64_t ii, int64_t jj) {
+        if constexpr (RM == 4) {
+            if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); }
+            if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); }
+            if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); }
+            if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); }
+            if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); }
+            if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); }
+        } else if constexpr (RM == 2) {
+            if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); }
+            if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); }
+        }
+    }
+
+    template <int RM, int RN, int BM>
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+        GGML_ASSERT(m % (RM * BM) == 0);
+        const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN -1) / RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
+
+        // "round" bloc_size to "nearest" BN
+        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+        const int64_t nb_job = ytiles * NB_BN;
+
+        if (params->ith == 0) {
+            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
+            // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, params->nth);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii = (job % ytiles) * RM * BM;
+            const int64_t jb =  job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj < jj2; jj += RN - 1) {
+                        gemm_bloc<RM, RN - 1>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == jj2);
+            }
+
+            job = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
+
+        ggml_barrier(params->threadpool);
+        return;
+    }
+
+    const ggml_compute_params * params;
+    const TA *const A;
+    const TB *const B;
+    TC *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+};
+#endif
+
 //////////////////////////////////////////////////////////////////////////////////////////
 // QUANT ZERO MATRIX MULTIPLICATION
 
@@ -1085,10 +1798,27 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template <typename T>
+struct mma_instr;
+
+template<>
+struct mma_instr<ggml_bf16_t> {
+    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
+        __builtin_mma_xvbf16ger2pp(acc, a, b);
+    }
+};
+
+template<>
+struct mma_instr<ggml_fp16_t> {
+    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
+        __builtin_mma_xvf16ger2pp(acc, a, b);
+    }
+};
+
 template <typename TA, typename TB, typename TC>
-class tinyBLAS_BF16_PPC {
+class tinyBLAS_HP16_PPC {
   public:
-    tinyBLAS_BF16_PPC(int64_t k,
+    tinyBLAS_HP16_PPC(int64_t k,
                 const TA *A, int64_t lda,
                 const TB *B, int64_t ldb,
                 TC *C, int64_t ldc,
@@ -1406,8 +2136,8 @@ class tinyBLAS_BF16_PPC {
             packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
             packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -1423,8 +2153,8 @@ class tinyBLAS_BF16_PPC {
             packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
             packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -1443,10 +2173,10 @@ class tinyBLAS_BF16_PPC {
             packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
             packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
-                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
             }
         }
 
@@ -1477,7 +2207,7 @@ class tinyBLAS_BF16_PPC {
                 packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                 packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                 for (int x = 0; x<2; x++) {
-                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -1512,8 +2242,8 @@ class tinyBLAS_BF16_PPC {
                 packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                 packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                 for (int x = 0; x<4; x++) {
-                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -1572,43 +2302,299 @@ class tinyBLAS_BF16_PPC {
     const int nth;
 };
 
-    template 
-    tinyBLAS_Q0_PPC::tinyBLAS_Q0_PPC(int64_t k,
-        const TA *A, int64_t lda,
-        const block_q8_0 *B, int64_t ldb,
-        float *C, int64_t ldc,
-        int ith, int nth)
+template <typename TA>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+             const TA * A, int64_t lda,
+             const block_q8_0 * B, int64_t ldb,
+             float * C, int64_t ldc,
+             int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-                kc = 64;
     }
 
-    template
-    void tinyBLAS_Q0_PPC::matmul(int64_t m, int64_t n) {
-        int mc = 64; int nc = 64;
-        if (n % 8 == 0 && n < nc) {
-                nc = n;
-                mc = 32 ;
-                kc = 32;
-        }
-        const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
-        if (is_aligned) {
-            this->matmul_tiled_q0(m, n, mc, nc, kc);
+    void matmul(int64_t m, int64_t n) {
+        const int64_t mc = 64;
+        const int64_t kc = 64;
+        int64_t nc = 64;
+        int64_t n_aligned = 0;
+        if (n % 64 == 0) {
+            n_aligned = n;
+        } else if (n == 4) {
+            n_aligned = 4;
+        } else if (n < 64) {
+            n_aligned = (n / 8) * 8;
+        } else {
+            n_aligned = (n / 64) * 64;
+        }
+
+        if (n_aligned > 0) {
+            if (n_aligned % 64 == 0)      nc = 64;
+            else if (n_aligned == n)      nc = n;
+            else if (n_aligned % 32 == 0) nc = 32;
+            else if (n_aligned % 24 == 0) nc = 24;
+            else if (n_aligned % 16 == 0) nc = 16;
+            else                          nc = 8;
+        }
+        bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
+        if (can_use_tiled) {
+            matmul_tiled(m, n_aligned, mc, nc, kc);
+            if (n > n_aligned) {
+                mnpack(0, m, n_aligned, n);
+            }
         } else {
             mnpack(0, m, 0, n);
         }
     }
 
-   template
-   template
-   void tinyBLAS_Q0_PPC::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) {
+  private:
+    inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
+            }
+        }
+    }
+
+    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
+                *c_ptr += *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    template <typename ArrayType>
+    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
+        }
+    }
+
+    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset + 16);
+        vec_xst(t7, 0, vecOffset + 32);
+        vec_xst(t8, 0, vecOffset + 48);
+    }
+
+    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
+        const vector signed char lowMask = vec_splats((signed char)0x0F);
+        const vector signed char v8      = vec_splats((signed char)0x08);
+        const vector unsigned char v4    = vec_splats((unsigned char)4);
+        lo = vec_and(packed, lowMask);
+        hi = vec_sr(packed, v4);
+        lo = vec_sub(lo, v8);
+        hi = vec_sub(hi, v8);
+    }
+
+    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        for (int i = 0; i < 4; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        for (int i = 4; i < 8; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        s[0] = vec_perm(t[0], t[2], swiz3);
+        s[1] = vec_perm(t[0], t[2], swiz4);
+        s[2] = vec_perm(t[1], t[3], swiz3);
+        s[3] = vec_perm(t[1], t[3], swiz4);
+        s[4] = vec_perm(t[4], t[6], swiz3);
+        s[5] = vec_perm(t[4], t[6], swiz4);
+        s[6] = vec_perm(t[5], t[7], swiz3);
+        s[7] = vec_perm(t[5], t[7], swiz4);
+        for (int i = 0; i < 8; ++i) {
+            vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
+        }
+    }
+
+    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
+        vector signed short i16_hi = vec_unpackh(raw);
+        vector signed short i16_lo = vec_unpackl(raw);
+
+        vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
+        vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
+        vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
+        vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
+        out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
+        out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
+    }
+
+    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        for (int i = 0; i < rows; i += 8) {
+            const block_q4_0 * rows_base[8];
+            for (int r = 0; r < 8; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[8][4];
+                for (int r = 0; r < 8; r++) {
+                    const block_q4_0 * current_blk = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
+                    vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
+                    vector signed char c1, c2;
+                    unpack_q4_to_q8(v_qs, c1, c2);
+                    convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int c = 0; c < 4; c++) {
+                    vector unsigned char c_arr[8];
+                    for (int r = 0; r < 8; r++) {
+                        c_arr[r] = (vector unsigned char)hp_res[r][c];
+                    }
+                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
+                    vecOffset += 128;
+                }
+            }
+        }
+    }
+
+    template <int chunk_size>
+    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        for (int i = 0; i < rows; i += chunk_size) {
+            const block_q8_0 * rows_base[chunk_size];
+            for (int r = 0; r < chunk_size; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[chunk_size][4];
+                for (int r = 0; r < chunk_size; r++) {
+                    const block_q8_0 * b = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
+                    vector signed char c[2];
+                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
+                    __builtin_vsx_disassemble_pair(c, & pair);
+                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int col = 0; col < 4; col++) {
+                    if constexpr (chunk_size == 8) {
+                        vec_t t[8];
+                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
+                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
+                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
+                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
+
+                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
+                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
+                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
+                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
+                        vecOffset += 128;
+                    } else {
+                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+
+                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vecOffset += 64;
+                    }
+                }
+            }
+        }
+    }
+
+    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        if (rows == 4) {
+            pack_q8_block<4>(a, lda, rows, blocks, vec);
+        } else {
+            pack_q8_block<8>(a, lda, rows, blocks, vec);
+        }
+    }
+
+    template <int size>
+    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
         int64_t i, j;
-        TA *aoffset = NULL;
-        int8_t *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        TA * aoffset = NULL;
+        int8_t * vecOffset = NULL;
+        TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
+        TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
         vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
         vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TA *>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -1634,18 +2620,18 @@ class tinyBLAS_BF16_PPC {
                         c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs));
                         c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs));
 
-                        process_q4_elements(c1, &comparray[0]);
-                        process_q4_elements(c2, &comparray[1]);
-                        process_q4_elements(c3, &comparray[2]);
-                        process_q4_elements(c4, &comparray[3]);
-                        process_q4_elements(c5, &comparray[4]);
-                        process_q4_elements(c6, &comparray[5]);
-                        process_q4_elements(c7, &comparray[6]);
-                        process_q4_elements(c8, &comparray[7]);
+                        process_q4_elements(c1, & comparray[0]);
+                        process_q4_elements(c2, & comparray[1]);
+                        process_q4_elements(c3, & comparray[2]);
+                        process_q4_elements(c4, & comparray[3]);
+                        process_q4_elements(c5, & comparray[4]);
+                        process_q4_elements(c6, & comparray[5]);
+                        process_q4_elements(c7, & comparray[6]);
+                        process_q4_elements(c8, & comparray[7]);
                         vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                        vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-                        vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-                        vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+                        vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
+                        vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
+                        vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
                         aoffset1 += lda;
                         aoffset2 += lda;
                         aoffset3 += lda;
@@ -1676,12 +2662,12 @@ class tinyBLAS_BF16_PPC {
                     c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs));
                     c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs));
 
-                    process_q4_elements(c1, &comparray[0]);
-                    process_q4_elements(c2, &comparray[1]);
-                    process_q4_elements(c3, &comparray[2]);
-                    process_q4_elements(c4, &comparray[3]);
+                    process_q4_elements(c1, & comparray[0]);
+                    process_q4_elements(c2, & comparray[1]);
+                    process_q4_elements(c3, & comparray[2]);
+                    process_q4_elements(c4, & comparray[3]);
                     vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -1705,12 +2691,12 @@ class tinyBLAS_BF16_PPC {
                         case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs));
                             break;
                     }
-                    process_q4_elements(c1, &comparray[0]);
-                    process_q4_elements(c2, &comparray[1]);
-                    process_q4_elements(c3, &comparray[2]);
-                    process_q4_elements(c4, &comparray[3]);
+                    process_q4_elements(c1, & comparray[0]);
+                    process_q4_elements(c2, & comparray[1]);
+                    process_q4_elements(c3, & comparray[2]);
+                    process_q4_elements(c4, & comparray[3]);
                     vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -1721,39 +2707,38 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template
     template
-    void tinyBLAS_Q0_PPC::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
         int64_t i, j;
-        block_q8_0 *aoffset = NULL;
-        VA *vecOffset = NULL;
-        block_q8_0* aoffsets[8];
+        block_q8_0 * aoffset = NULL;
+        VA * vecOffset = NULL;
+        block_q8_0 * aoffsets[8];
         __vector_pair arr[8];
         VB c[8][2] = {0};
         VB c1[8] = {0}; VB c2[8] = {0};
-        aoffset = const_cast(a);
+        aoffset = const_cast(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
                 aoffsets[0] = aoffset;
                 for (int it = 1; it < 8; it++)
-                    aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffsets[it] = aoffsets[it - 1] + lda;
                 aoffset += 8 * lda;
 
                 i = (cols >> 3);
                 if (i > 0) {
                 do {
                     for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                         c1[it] = c[it][0];
                         c2[it] = c[it][1];
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
+                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
+                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
                     for (int it = 0; it < 8; it++)
                         aoffsets[it] += lda;
                     vecOffset += 256;
@@ -1772,13 +2757,13 @@ class tinyBLAS_BF16_PPC {
             if (i > 0) {
                do {
                     for (int it = 0; it < 4; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                         c1[it] = c[it][0];
                         c2[it] = c[it][1];
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 4; it++) {
                         aoffsets[it] += lda;
                     }
@@ -1791,24 +2776,24 @@ class tinyBLAS_BF16_PPC {
         if (rows & 3) {
             aoffsets[0]  = aoffset;
             for (int it = 1; it < 3; it++ )
-                aoffsets[it] = aoffsets[it-1] + lda;
+                aoffsets[it] = aoffsets[it - 1] + lda;
             i = (cols >> 3);
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
-                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
+                                __builtin_vsx_disassemble_pair(c[2], & arr[2]);
                                 c1[2] = c[2][0]; c2[2] = c[2][1];
-                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
-                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
+                                __builtin_vsx_disassemble_pair(c[1], & arr[1]);
                                 c1[1] = c[1][0]; c2[1] = c[1][1];
-                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
-                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
+                                __builtin_vsx_disassemble_pair(c[0], & arr[0]);
                                 c1[0] = c[0][0]; c2[0] = c[0][1];
                                 break;
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 3; it++)
                          aoffsets[it] += lda;
                     vecOffset += 128;
@@ -1818,8 +2803,7 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template
-    void tinyBLAS_Q0_PPC::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int m_rem = MIN(m - m0, 16);
         int n_rem = MIN(n - n0, 16);
 
@@ -1856,8 +2840,7 @@ class tinyBLAS_BF16_PPC {
     }
 
 
-    template
-    void tinyBLAS_Q0_PPC::KERNEL_4x8(int64_t ii, int64_t jj) {
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
         std::array comparray {};
@@ -1865,26 +2848,26 @@ class tinyBLAS_BF16_PPC {
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v) {
-               packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+               packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
             }
-            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
             }
             for (int I = 0; I<4; I++) {
                 for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 4; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -1895,15 +2878,14 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 0, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 0, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii, jj+4, 4, fin_res);
+        save_res(ii, jj + 4, 4, fin_res);
     }
 
-    template
-    void tinyBLAS_Q0_PPC::KERNEL_8x4(int64_t ii, int64_t jj) {
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
         std::array comparray {};
@@ -1911,25 +2893,25 @@ class tinyBLAS_BF16_PPC {
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v) {
-               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-            packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            packNormal((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
             }
-            for (int I = 0; I<8; I++) {
-                for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+            for (int I = 0; I < 8; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -1940,15 +2922,14 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii+4, jj, 4, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
     }
 
-    template
-    void tinyBLAS_Q0_PPC::KERNEL_8x8(int64_t ii, int64_t jj) {
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
         acc_t acc_4, acc_5, acc_6, acc_7;
@@ -1957,30 +2938,30 @@ class tinyBLAS_BF16_PPC {
         vector float vs[16] = {0};
         bool isAblock_q4 = std::is_same_v;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
-            __builtin_mma_xxsetaccz(&acc_2);
-            __builtin_mma_xxsetaccz(&acc_3);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
+            __builtin_mma_xxsetaccz(& acc_2);
+            __builtin_mma_xxsetaccz(& acc_3);
             if (std::is_same_v) {
-               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
-                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
+                __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
             }
-            for (int I = 0; I<8; I++) {
-                for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+            for (int I = 0; I < 8 ; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -1991,19 +2972,99 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 4, 4, comparray, vs, fin_res);
-            compute(&acc_2, 0, 8, comparray, vs, fin_res);
-            compute(&acc_3, 4, 12, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_2, 0, 8, comparray, vs, fin_res);
+            compute(& acc_3, 4, 12, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii+4, jj, 4, fin_res);
-        save_res(ii, jj+4, 8, fin_res);
-        save_res(ii+4, jj+4, 12, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
+        save_res(ii, jj + 4, 8, fin_res);
+        save_res(ii + 4, jj + 4, 12, fin_res);
+    }
+
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
+        acc_t acc[8];
+        for (int i = 0; i < mc ; i += 16) {
+            for (int j = 0; j < nc; j += 8) {
+                int A0_base = (i / 16) * (2 * 32 * kc);
+                int B0_base = (j / 8) * (32 * kc);
+                for (int x = 0; x < 8; x++) {
+                     __builtin_mma_xxsetaccz(&acc[x]);
+                }
+                for (int64_t kk = 0; kk < kc; kk++) {
+                    int A0_block_idx = A0_base + kk * 32;
+                    int B0_block_idx = B0_base + kk * 32;
+                    int A1_block_idx = A0_block_idx + 32 * kc;
+                    int B1_block_idx = B0_block_idx + 32 * kc;
+                    vec_t * A0_block = & vec_A[A0_block_idx];
+                    vec_t * B0_block = & vec_B[B0_block_idx];
+                    vec_t * A1_block = & vec_A[A1_block_idx];
+                    for (int it = 0; it < 4; it++) {
+                        for (int x = 0; x < 4; x++) {
+                            __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                        }
+                    }
+                }
+                if (l == 0) {
+                    save_acc(& acc[0], ii + i, jj + j);
+                    save_acc(& acc[1], ii + i, jj + j + 4);
+                    save_acc(& acc[2], ii + i + 4, jj + j);
+                    save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    save_acc(& acc[4], ii + i + 8, jj + j);
+                    save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    save_acc(& acc[6], ii + i + 12, jj + j);
+                    save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                } else {
+                    add_save_acc(& acc[0], ii + i, jj + j);
+                    add_save_acc(& acc[1], ii + i, jj + j + 4);
+                    add_save_acc(& acc[2], ii + i + 4, jj + j);
+                    add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    add_save_acc(& acc[4], ii + i + 8, jj + j);
+                    add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    add_save_acc(& acc[6], ii + i + 12, jj + j);
+                    add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                }
+            }
+        }
+    }
+
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        vec_t A_pack[mc * kc * 4];
+        vec_t B_pack[nc * kc * 4];
+        constexpr bool is_Ablock_q4 = std::is_same_v;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles  = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr(is_Ablock_q4) {
+                    packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                } else {
+                    packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                }
+                packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
+            }
+        }
     }
 
-    template
-    void tinyBLAS_Q0_PPC::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2025,32 +3086,32 @@ class tinyBLAS_BF16_PPC {
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
             vector float CA[4] = {0};
-            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
-            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
             for (int l = 0; l < k; l++) {
-                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-                __builtin_mma_xxsetaccz(&acc_0);
+                __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(& acc_0);
                 if (isAblock_q4) {
-                   packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                    packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
                 } else {
-                   packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                    packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
                 }
-                packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
-                for(int x = 0; x < 8; x+=4) {
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                packNormal((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
+                for (int x = 0; x < 8; x += 4) {
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
                 }
-                for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d));
+                for (int I = 0; I < RM; I++) {
+                    for (int J = 0; J < RN; J++) {
+                        *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                     }
                 }
-                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                __builtin_mma_disassemble_acc(vec_C, & acc_0);
                 if (!isAblock_q4) {
-                    auto aoffset = A+(ii*lda)+l;
+                    auto aoffset = A + (ii * lda) + l;
                     for (int i = 0; i < RM; i++) {
                         comparray[i] = 0;
                         int ca = 0;
@@ -2071,9 +3132,21 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template
+    template
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            assert(false && "RN/RM values not supported");
+        }
+    }
+
     template 
-    NOINLINE void tinyBLAS_Q0_PPC::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2085,12 +3158,20 @@ class tinyBLAS_BF16_PPC {
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            this->kernel(ii, jj);
+            kernel(ii, jj);
         }
     }
-
-template class tinyBLAS_Q0_PPC;
-template class tinyBLAS_Q0_PPC;
+    const TA * const A;
+    const block_q8_0 * const B;
+    float * C;
+    const int64_t k;
+    int64_t kc;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
 
 class tinyBLAS_PPC {
   public:
@@ -2657,6 +3738,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__riscv_zvfh)
+    #if LMUL == 1
+        tinyBLAS_RVV tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+    #elif LMUL == 2
+        tinyBLAS_RVV tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+    #else // LMUL = 4
+        tinyBLAS_RVV tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+    #endif
+        return tb.matmul(m, n);
 #else
         return false;
 #endif
@@ -2688,17 +3787,38 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             return tb.matmul(m, n);
         }
 #elif defined(__MMA__)
-        if ((k % 8))
-                return false;
-        if(Btype == GGML_TYPE_BF16) {
-           tinyBLAS_BF16_PPC tb{ k,
-            (const ggml_bf16_t *)A, lda,
-            (const ggml_bf16_t *)B, ldb,
-            (float *)C, ldc,
-            params->ith, params->nth};
-        tb.matmul(m, n);
-        return true;
+        if (k % 8) {
+            return false;
         }
+
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS_HP16_PPC tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth };
+
+            tb.matmul(m, n);
+            return true;
+        }
+#elif defined(__riscv_zvfbfwma)
+        #if LMUL == 1
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+        #elif LMUL == 2
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+        #else // LMUL = 4
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc};
+        #endif
+            return tb.matmul(m, n);
 #endif
         return false;
     }
@@ -2748,6 +3868,41 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                 (float *)C, ldc};
             return tb.matmul(m, n);
         }
+#elif defined(__riscv_zvfh)
+        if (Btype == GGML_TYPE_F16) {
+        #if LMUL == 1
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+        #elif LMUL == 2
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+        #else // LMUL = 4
+            tinyBLAS_RVV tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+        #endif
+            return tb.matmul(m, n);
+        }
+#elif defined(__MMA__)
+        if (k % 8) {
+            return false;
+        }
+
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS_HP16_PPC tb{ k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth };
+
+            tb.matmul(m, n);
+            return true;
+        }
 #endif
         return false;
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
index f4aae533213..5988ef40be9 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp
@@ -3,14 +3,14 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "simd-gemm.h"
 #include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
-#include 
 #include 
+#include <algorithm>
 #include 
-#include 
 
 // ggml_compute_forward_dup
 
@@ -2097,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2114,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2136,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2153,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2277,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2294,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_erf_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2316,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2333,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_erf_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2380,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2397,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_quick_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2419,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2436,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_quick_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2483,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2500,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_silu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -2522,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2539,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_silu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -7110,12 +7174,13 @@ void ggml_compute_forward_conv_2d_dw(
     }
 }
 
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
+// ggml_compute_forward_pool_1d_ksp
+static void ggml_compute_forward_pool_1d_ksp(
         const ggml_compute_params * params,
         const ggml_op_pool op,
         const int k,
+        const int s,
+        const int p,
         ggml_tensor * dst) {
 
     const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7191,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
         return;
     }
 
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
+    const int64_t IW = src->ne[0];
+    const int64_t OW = dst->ne[0];
 
-    const int64_t rs = dst->ne[0];
+    const int64_t nr = ggml_nrows(src);
 
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
+    for (int64_t ir = 0; ir < nr; ++ir) {
+        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
+        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            float res = 0;
             switch (op) {
-                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_AVG: res = 0.0f;     break;
+                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
+
+            int count = 0;
+            const int base = (int) ow * s - p;
+
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                const int j = base + ki;
+                if (j < 0 || j >= (int) IW) {
+                    continue;
+                }
+
+                float v;
+                if (src->type == GGML_TYPE_F32) {
+                    v = ((const float *) srow_bytes)[j];
+                } else {
+                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
+                }
+
                 switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                    case GGML_OP_POOL_AVG: res += v;                break;
+                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
-                ++j;
+
+                ++count;
             }
+
             switch (op) {
-                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
+                case GGML_OP_POOL_MAX:                                           break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
-        }
 
-        cdata += src->nb[1];
-        drow  += rs;
+            drow[ow] = res;
+        }
     }
 }
 
@@ -7173,10 +7255,8 @@ void ggml_compute_forward_pool_1d(
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
 
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 }
 
 // ggml_compute_forward_pool_2d
@@ -7194,6 +7274,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
+
 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -7217,11 +7298,13 @@ void ggml_compute_forward_pool_2d(
     while (cdata < data_end) {
         for (int oy = 0; oy < py; ++oy) {
             float * const drow = dplane + oy * px;
+            float * const out  = drow;
+
             for (int ox = 0; ox < px; ++ox) {
-                float * const out =  drow + ox;
+                float res = 0;
                 switch (op) {
-                    case GGML_OP_POOL_AVG:     *out = 0;        break;
-                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_AVG: res = 0;        break;
+                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 
@@ -7229,24 +7312,32 @@ void ggml_compute_forward_pool_2d(
                 const int iy = offset1 + oy * s1;
 
                 for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
+                        continue;
+                    }
+
                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
+                        if (j < 0 || j >= src->ne[0]) {
+                            continue;
+                        }
+
                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_AVG: res += srow_j;                break;
+                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
                             case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
                         }
                     }
                 }
                 switch (op) {
-                    case GGML_OP_POOL_AVG:           *out /= ka; break;
-                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_AVG:           res /= ka; break;
+                    case GGML_OP_POOL_MAX:                      break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
+
+                out[ox] = res;
             }
         }
 
@@ -7603,8 +7694,7 @@ static void ggml_compute_forward_pad_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8059,12 +8149,14 @@ void ggml_compute_forward_top_k(
     }
 }
 
-// ggml_compute_forward_flash_attn_ext
-
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
         ggml_tensor * dst,
-        int ir0, int ir1) {
+        int ir0, int ir1,
+        int64_t ic_start, int64_t ic_end,
+        float * partials, int64_t partial_stride) {
+
+    const bool write_partials = (partials != nullptr);
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
@@ -8141,7 +8233,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
     int ith = params->ith;
 
-    // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
         const int iq3 = ir/(neq2*neq1);
@@ -8181,7 +8272,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
-        for (int64_t ic = 0; ic < nek1; ++ic) {
+
+        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
@@ -8254,8 +8346,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             }
         }
 
-        // sinks
-        if (sinks) {
+        // sinks - apply only on the first kv-chunk
+        if (sinks && ic_start == 0) {
             const float s = ((float *)((char *) sinks->data))[h];
 
             float ms = 1.0f;
@@ -8263,6 +8355,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
             if (s > M) {
                 ms = expf(M - s);
+                M = s;
                 ggml_vec_scale_f32(DV, VKQ32, ms);
             } else {
                 vs = expf(s - M);
@@ -8271,20 +8364,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             S = S*ms + vs;
         }
 
-        // V /= S
-        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
+        if (write_partials) {
+            // Write M, S, VKQ to partials for later reduction
+            // partials layout: [M, S, VKQ[DV]] per query head
+            float * partial = partials + ir * partial_stride;
+            partial[0] = M;
+            partial[1] = S;
+            memcpy(partial + 2, VKQ32, DV * sizeof(float));
+        } else {
+            // V /= S
+            const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+            ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext_tiled(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        int ir0, int ir1) {
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(k->type == v->type);
+    const ggml_type kv_type = k->type;
+
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    int ith = params->ith;
+
+    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
+    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
+
+    int ir = ir0;
+    while (ir < ir1) {
+        // q indices for the start of this tile
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        // Number of valid rows in this tile:
+        // - limited by tile size (Q_TILE_SZ)
+        // - limited by chunk boundary (ir1 - ir)
+        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
+        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
+        GGML_ASSERT(tile_rows > 0);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S[Q_TILE_SZ];
+        float M[Q_TILE_SZ];
+
+        for (int i = 0 ; i < Q_TILE_SZ; ++i) {
+            S[i] = 0.;
+            M[i] = -INFINITY;
+        }
+
+        // Per-thread scratch layout:
+        // Q_q:    Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
+        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
+        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
+        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
+        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile)
+        // K_f32:  KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
+        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
+
+        void  * Q_q    = base;
+        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
+        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
+        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
+        float * V32    = VKQ32 + Q_TILE_SZ * DV;
+        float * K_f32  = V32 + KV_TILE_SZ * DV;
+
+        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
+        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        {
+            float * Q_f32 = (float *)Q_q;
+            for (int tq = 0; tq < tile_rows; tq++) {
+                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
+            }
+            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
+            }
+        }
+
+        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
+        memset(V32,   0, KV_TILE_SZ * DV * sizeof(float));
+
+        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
+
+            // skip the tile entirely if all the masks are -inf
+            if (mask) {
+                bool can_skip = true;
+                for (int tq = 0; tq < tile_rows; tq++) {
+                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
+                    for (int tk = 0; tk < kv_tile; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
+                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
+                            can_skip = false;
+                        }
+                    }
+                    // Pad remaining mask entries with -inf
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
+                }
+
+                if (can_skip) {
+                    continue;
+                }
+            }
+
+            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
+            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
+                if (kv_type == GGML_TYPE_F16) {
+                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
+                    }
+                } else {
+                    const float * k_f32_src = (const float *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
+                    }
+                }
+            }
+            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
+            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
+
+            // Set padded KQ entries to -inf so softmax gives them zero weight
+            if (kv_tile < KV_TILE_SZ) {
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
+                }
+            }
+
+            if (logit_softcap != 0.0f) {
+                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
+                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
+            }
+
+            if (mask) {
+                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
+            }
+
+            bool skip[Q_TILE_SZ] = {};
+
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                float * kq_row = KQ + tq * KV_TILE_SZ;
+
+                float tile_max;
+                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
+
+                if (tile_max == -INFINITY) {
+                    skip[tq] = true;
+                    continue;
+                }
+
+                const float Mold = M[tq];
+                const float Mnew = fmaxf(Mold, tile_max);
+
+                if (Mnew > Mold) {
+                    const float ms = expf(Mold - Mnew);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                    S[tq] *= ms;
+                }
+                M[tq] = Mnew;
+
+
+                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
+            }
+
+            // V accumulation: VKQ32 += softmax(KQ) * V
+            // Pack V tile to contiguous F32, zero-padded
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
+                if (kv_type == GGML_TYPE_F16) {
+                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
+                } else {
+                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
+                }
+            }
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                if (skip[tq]) {
+                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
+                }
+            }
+            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
+        }
+
+        // sinks (apply only to valid rows in the tile)
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            for (int tq = 0; tq < tile_rows; tq++) {
+                float ms = 1.0f;
+                float vs = 1.0f;
+
+                if (s > M[tq]) {
+                    ms = expf(M[tq] - s);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                } else {
+                    vs = expf(s - M[tq]);
+                }
+
+                S[tq] = S[tq] * ms + vs;
+            }
+        }
+
+        for (int tq = 0; tq < tile_rows; tq++) {
+            // V /= S
+            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
+            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
+
+            // dst indices
+            const int i1 = iq1 + tq;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
+        }
+
+        ir += tile_rows;
+    }
+}
+
+// Reduction function: combines partial results across KV chunks
+// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
+static void ggml_flash_attn_ext_reduce_partials(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const int64_t n_chunks,
+        const int64_t chunk_size) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    const int64_t DK        = k->ne[0];
+    const int64_t DV        = v->ne[0];
+    const int64_t nek1      = k->ne[1];
+    const int64_t n_q_heads = q->ne[2];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
+    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;
+
+    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
+    const int64_t partial_size     = 2 + DV;
+    const float * partials_base    = (const float *) params->wdata + partials_offset;
+
+    // Output layout
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const size_t  nb1 = dst->nb[1];
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+    // Each thread reduces a subset of query heads
+    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
+        float   M_final   = -INFINITY;
+        float   S_final   = 0.0f;
+        float * VKQ_final = thread_wdata;
+        memset(VKQ_final, 0, DV * sizeof(float));
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        // Combine partials from all chunks
+        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+            const int64_t ic_start = chunk_idx * chunk_size;
+            if (ic_start >= nek1) continue;
+
+            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
+            const float   M_chunk   = partial[0];
+            const float   S_chunk   = partial[1];
+            const float * VKQ_chunk = partial + 2;
+
+            if (S_chunk == 0.0f) continue;
+
+            const float M_new     = fmaxf(M_final, M_chunk);
+            const float scale_old = expf(M_final - M_new);
+            const float scale_new = expf(M_chunk - M_new);
+
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
+            }
+            S_final = S_final * scale_old + S_chunk * scale_new;
+            M_final = M_new;
+        }
+
+        // Normalize and write to output
+        if (S_final != 0.0f) {
+            const float S_inv = 1.0f / S_final;
+            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
+        }
+        // iq1=0, iq3=0 for decode
+        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
     }
 }
 
@@ -8309,6 +8768,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t DV = nev0;
     const int64_t N  = neq1;
 
+
     GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
@@ -8329,47 +8789,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int64_t nr = neq1*neq2*neq3;
-
-    // rows per thread
     const int ith = params->ith;
     const int nth = params->nth;
 
-    // disable for NUMA
-    const bool disable_chunking = ggml_is_numa();
+    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
+    const bool use_ref = params->use_ref;
 
-    // 4x chunks per thread
-    int nth_scaled = nth * 4;
-    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
-    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
 
-    if (nth == 1 || nchunk < nth || disable_chunking) {
-        nchunk = nth;
-    }
+    if (use_split_kv_path) {
+        const int64_t chunk_size = (nek1 + nth - 1) / nth;
 
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        ggml_threadpool_chunk_set(params->threadpool, nth);
-    }
+        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
+        const int64_t partial_size  = 2 + DV;
+        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
 
-    ggml_barrier(params->threadpool);
+        const int64_t ic_start = ith * chunk_size;
+        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);
+
+        const int64_t partial_stride = nth * partial_size;
+        float *       chunk_partials = partials_base + ith * partial_size;
+
+        if (ic_start < nek1) {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
+                    params, dst, q_head, q_head + 1, ic_start, ic_end,
+                    chunk_partials, partial_stride);
+            }
+        } else {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                float * q_partials = chunk_partials + q_head * partial_stride;
+                q_partials[0] = -INFINITY;  // M
+                q_partials[1] = 0.0f;       // S
+            }
+        }
 
-    // The number of elements in each chunk
-    const int64_t dr = (nr + nchunk - 1) / nchunk;
+        ggml_barrier(params->threadpool);
+        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
+    } else {
 
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
+        // total rows in q
+        const int64_t nr = neq1*neq2*neq3;
 
-    while (current_chunk < nchunk) {
-        const int64_t ir0 = dr * current_chunk;
-        const int64_t ir1 = MIN(ir0 + dr, nr);
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
 
-        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        // 4x chunks per thread
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
 
-        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
+        }
+
+        if (ith == 0) {
+            ggml_threadpool_chunk_set(params->threadpool, nth);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
+        bool use_tiled = !use_ref &&
+                               (q->type == GGML_TYPE_F32 &&
+                                kv_is_f32_or_f16 &&
+                                k->type == v->type &&
+                                neq1 >= Q_TILE_SZ);
+#ifdef GGML_SIMD
+        use_tiled &= (DV % GGML_F32_EPR == 0);
+#endif
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            const int64_t ir0 = dr * current_chunk;
+            const int64_t ir1 = MIN(ir0 + dr, nr);
+
+            if (use_tiled) {
+                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+            } else {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
+            }
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp
index b70ea7d78b9..5edba4212f6 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp
@@ -256,6 +256,402 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
     ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
 }
 
+template <int M, int N>
+static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;
+    const int     blocks_per_half   = 64 / blocklen;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0f;
+        }
+
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
+                const int base_h = base_l + 64;
+
+                const int scale_idx_l = base_l / 16;
+                const int scale_idx_h = base_h / 16;
+
+                const int qh_shift_l = ((base_l % 128) / 32) * 2;
+                const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                const int qh_half_l = (base_l / 128) * 32;
+                const int qh_half_h = (base_h / 128) * 32;
+
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
+                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                    int sumi_l = 0;
+                    int sumi_h = 0;
+
+                    for (int i = 0; i < blocklen; i++) {
+                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
+                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);
+                        const int qh_chunk_l  = qh_idx_l / blocklen;
+                        const int qh_pos_l    = qh_idx_l % blocklen;
+                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
+                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
+                        const int qh_chunk_h  = qh_idx_h / blocklen;
+                        const int qh_pos_h    = qh_idx_h % blocklen;
+                        const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                        const int q_l = ((hi_2_l << 4) | l_4) - 32;
+                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                        const int8_t a_l = a_ptr[l].qs[base_l + i];
+                        const int8_t a_h = a_ptr[l].qs[base_h + i];
+
+                        sumi_l += q_l * a_l;
+                        sumi_h += q_h * a_h;
+                    }
+
+                    sumf[j] +=
+                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+        }
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;
+    const int     blocks_per_half   = 64 / blocklen;
+    const int     q8_half_stride    = 512;
+    const int     q8_low_high_step  = 256;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+
+    float sumf[4][8];
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0f;
+                }
+            }
+
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
+                    const int base_h = base_l + 64;
+
+                    const int scale_idx_l = base_l / 16;
+                    const int scale_idx_h = base_h / 16;
+
+                    const int qh_shift_l = ((base_l % 128) / 32) * 2;
+                    const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                    const int qh_half_l = (base_l / 128) * 32;
+                    const int qh_half_h = (base_h / 128) * 32;
+
+                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
+
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
+                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                            int sumi_l = 0;
+                            int sumi_h = 0;
+
+                            for (int i = 0; i < blocklen; i++) {
+                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
+                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                                const int qh_idx_l   = qh_half_l + ((base_l + i) % 32);
+                                const int qh_chunk_l = qh_idx_l / blocklen;
+                                const int qh_pos_l   = qh_idx_l % blocklen;
+                                const int qh_offset_l =
+                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
+                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                                const int qh_idx_h   = qh_half_h + ((base_h + i) % 32);
+                                const int qh_chunk_h = qh_idx_h / blocklen;
+                                const int qh_pos_h   = qh_idx_h % blocklen;
+                                const int qh_offset_h =
+                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                                const int q_l = ((hi_2_l << 4) | l_4) - 32;
+                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
+                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
+
+                                sumi_l += q_l * q8_l;
+                                sumi_h += q_h * q8_h;
+                            }
+
+                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
+                                          a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int         blocklen          = M;
+    constexpr int         ncols_interleaved = N;
+    const int             qk                = QK_K;
+    const int             nb                = n / qk;
+    static const uint32_t kmask1            = 0x3f3f3f3f;
+    static const uint32_t kmask2            = 0x0f0f0f0f;
+    static const uint32_t kmask3            = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float    sumf[ncols_interleaved];
+    float    sum_minf[ncols_interleaved];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j]     = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
+                utmp[sb * 4 + 3]      = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2]      = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                constexpr int scale_stride = 32;
+                uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
+                uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
+
+                const int qh_shift = (k / (32 / blocklen)) * 2;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi  = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                        const int qh_idx      = (k * blocklen + i) % 32;
+                        const int qh_chunk    = qh_idx / blocklen;
+                        const int qh_pos      = qh_idx % blocklen;
+                        const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
+
+                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                        const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                        const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                        const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
+
+                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
+                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int         blocklen          = M;
+    constexpr int         ncols_interleaved = N;
+    const int             qk                = QK_K;
+    const int             nb                = n / qk;
+    static const uint32_t kmask1            = 0x3f3f3f3f;
+    static const uint32_t kmask2            = 0x0f0f0f0f;
+    static const uint32_t kmask3            = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float    sumf[4][ncols_interleaved];
+    float    sum_minf[4][ncols_interleaved];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j]     = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int sb = 0; sb < 8; sb++) {
+                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
+                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                    utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                    utmp[sb * 4 + 2]      = uaux_0;
+                    utmp[sb * 4 + 0] &= kmask1;
+                }
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    constexpr int scale_stride = 32;
+                    uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
+                    uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
+
+                    const int qh_shift = (k / (32 / blocklen)) * 2;
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi  = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                                const int qh_idx   = (k * blocklen + i) % 32;
+                                const int qh_chunk = qh_idx / blocklen;
+                                const int qh_pos   = qh_idx % blocklen;
+                                const int b_qh_offset =
+                                    qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
+
+                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                                const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                                const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                                const int q8_offset = (k / (32 / blocklen)) * 256 +
+                                                      (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
+
+                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                    for (int m = 0; m < 4; m++) {
+                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
+                                              GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
+
 extern "C" {
 
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -474,15 +870,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     assert (n % qk == 0);
     assert (nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
     UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
     float sumf[8];
     float sum_minf[8];
@@ -616,6 +1005,23 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -692,49 +1098,219 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
-void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
     const int blocklen = 4;
 
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
     UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
-    {
-        float sumf[4][4];
-        int sumi;
+    float sumf[4];
+    int sumi;
 
-        for (int y = 0; y < nr / 4; y++) {
-            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-            for (int x = 0; x < nc / ncols_interleaved; x++) {
-                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                 }
-                for (int l = 0; l < nb; l++) {
-                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                        for (int m = 0; m < 4; m++) {
-                            for (int j = 0; j < ncols_interleaved; j++) {
-                                sumi = 0;
-                                for (int i = 0; i < blocklen; ++i) {
-                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                                }
-                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int   sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / blocklen); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+void ggml_gemv_q8_0_4x8_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int   sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / blocklen); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                                }
+                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                             }
                         }
                     }
@@ -952,15 +1528,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     assert (nr % 4 == 0);
     assert (nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
     float sumf[4][8];
     float sum_minf[4][8];
@@ -1118,6 +1686,21 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
 
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
@@ -1219,8 +1802,217 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int   sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / blocklen); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+                            }
+                            sumf[m][j] +=
+                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+void ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,
+                                     float * GGML_RESTRICT      s,
+                                     size_t                     bs,
+                                     const void * GGML_RESTRICT vx,
+                                     const void * GGML_RESTRICT vy,
+                                     int                        nr,
+                                     int                        nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int   sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / blocklen); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+                            }
+                            sumf[m][j] +=
+                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 } // extern "C"
 
+static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
+    block_q8_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK8_0 * 4 / blck_size_interleave;
+    for (int i = 0; i < end; ++i) {
+        int src_id     = i % 4;
+        int src_offset = (i / 4) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
+    }
+    return out;
+}
+
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
 
@@ -1309,9 +2101,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
         int src_offset = (i / 8) * blck_size_interleave;
         int dst_offset = i * blck_size_interleave;
 
+        // buffer large enough for the max interleave block size (8 bytes)
         uint64_t elems;
-        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
-        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
+        memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
     }
 
     // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
@@ -1397,17 +2190,146 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
     // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
     // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
 
-    for(int i = 0; i < 128; i++){
-
+    for (int i = 0; i < 128; i++) {
         // Index for selecting which q2k super block
         int src1 = (i % 16) / 2;
         // Index for selecting scale
         int src2 = ((i / 16) * 2) + (i % 2);
 
-        out.scales[i] = in[src1].scales[src2];
+        out.scales[i] = in[src1].scales[src2];
+    }
+    return out;
+}
+
+static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
+    block_q5_Kx8 out;
+    // Delta (scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 8; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    const int end = QK_K * 4 / blck_size_interleave;
+
+    // Interleave Q5_K quants by taking blck_size_interleave bytes at a time
+    for (int i = 0; i < end; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
+    }
+
+    // Repeat for high bits with the same chunk size, since
+    // the high bits are interleaved in Q5_K and the index is
+    // qh_idx = (qs_idx % 32);
+    // qh_val = qh[qh_idx] >> (qs_idx / 32);
+    for (int i = 0; i < end / 4; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
+    }
+
+    // The below logic is copied over from Q4_K
+    // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
+    // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
+    // The output Q5_Kx8 structure has 96 bytes
+    // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
+    // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
+    uint8_t s[8], m[8];
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = in[j].scales[i] & 63;
+            m[j] = in[j].scales[i + 4] & 63;
+        }
+
+        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
+            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
+        }
+
+        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    return out;
+}
+
+static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {
+    block_q6_Kx8  out;
+    constexpr int n_blocks = 8;  // Kx8
+    for (int i = 0; i < n_blocks; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end_ls = QK_K * 4 / blck_size_interleave;
+    // Interleave Q6_K quants by taking blck_size_interleave bytes at a time
+    for (int i = 0; i < end_ls; ++i) {
+        int src_id     = i % n_blocks;
+        int src_offset = (i / n_blocks) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elem_ls;
+        memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
+        memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
+    }
+
+    // Interleave high bits using same chunk size as low bits
+    const int end_hs = end_ls / 2;
+    for (int i = 0; i < end_hs; ++i) {
+        int src_id     = i % n_blocks;
+        int src_offset = (i / n_blocks) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elem_hs;
+        memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
+        memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
+    }
+
+    // The below logic is designed so as to unpack and rearrange scales in Q6_K
+    // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants
+    // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales
+    // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7]  (bl = block)
+    constexpr int n_scales = QK_K / 16;
+
+    for (int i = 0; i < n_blocks; i++) {
+        for (int j = 0; j < n_scales; j++) {
+            out.scales[j * n_blocks + i] = in[i].scales[j];
+        }
     }
-    return out;
 
+    return out;
 }
 
 static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
@@ -1491,7 +2413,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
 
     for (int b = 0; b < nrow; b += nrows_interleaved) {
         for (int64_t x = 0; x < nblocks; x++) {
-            for (int i  = 0; i < nrows_interleaved; i++ ) {
+            for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
             *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
@@ -1503,6 +2425,67 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q5_Kx8 *     dst = (block_q5_Kx8 *) t->data;
+    const block_q5_K * src = (const block_q5_K *) data;
+    block_q5_K         dst_tmp[8];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
+static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
+    const block_q6_K * src = (const block_q6_K *) data;
+    block_q6_K dst_tmp[8];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
 static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 8);
@@ -1534,6 +2517,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 4;
+
+    block_q8_0x4 *     dst = (block_q8_0x4 *) t->data;
+    const block_q8_0 * src = (const block_q8_0 *) data;
+    block_q8_0         dst_tmp[4];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK8_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
 static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
     block_iq4_nlx4 out;
 
@@ -1659,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }
 
+
+static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 4);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x4 * dst = (      block_mxfp4x4 *)t->data;
+
+    block_mxfp4 dst_tmp[4];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 4 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 8;
+            int src_offset = (i / 8) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 8);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x8 * dst = (      block_mxfp4x8 *)t->data;
+
+    block_mxfp4 dst_tmp[8];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 namespace ggml::cpu::repack {
 // repack
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@@ -1689,6 +2819,22 @@ template <> int repack(struct ggml_tensor * t, const void * da
     return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
 }
@@ -1702,6 +2848,22 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
     return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
+}
+
 // gemv
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1718,6 +2880,17 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <>
+void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int          n,
+                                            float *      s,
+                                            size_t       bs,
+                                            const void * vx,
+                                            const void * vy,
+                                            int          nr,
+                                            int          nc) {
+    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -1726,8 +2899,20 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -1738,6 +2923,22 @@ template <> void gemv(int n, float * s, size
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 // gemm
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1750,20 +2951,43 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+template <>
+void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int          n,
+                                            float *      s,
+                                            size_t       bs,
+                                            const void * vx,
+                                            const void * vy,
+                                            int          nr,
+                                            int          nc) {
+    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -1774,6 +2998,22 @@ template <> void gemm(int n, float * s, size
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 class tensor_traits_base : public ggml::cpu::tensor_traits {
   public:
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -2122,20 +3362,19 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
-                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
-                        src0_cur + src0_cur_start * nb01,
-                        src1_col, 1, src0_cur_end - src0_cur_start);
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(
+                    ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+                    src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
             }
         }
 #undef MMID_MATRIX_ROW
@@ -2151,7 +3390,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
@@ -2161,6 +3399,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
     static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
 
+    // instance for Q5_K
+    static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;
+    static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
+
+    // instance for Q6_K
+    static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
+    static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
+
     // instance for Q2
     static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
 
@@ -2168,6 +3414,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
 
+    // instance for MXFP4
+    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0;
+
+    // instance for Q8_0
+    static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
+
     if (cur->type == GGML_TYPE_Q4_0) {
         if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
             || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
@@ -2207,6 +3461,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q2_K_8x8_q8_K;
             }
         }
+    } else if (cur->type == GGML_TYPE_Q5_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q5_K_8x8_q8_K;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q5_K_8x4_q8_K;
+            }
+        }
+    } else if (cur->type == GGML_TYPE_Q6_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q6_K_8x8_q8_K;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q6_K_8x4_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
@@ -2218,6 +3494,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &iq4_nl_4x4_q8_0;
             }
         }
+    } else if (cur->type == GGML_TYPE_MXFP4) {
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &mxfp4_8x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &mxfp4_4x4_q8_0;
+            }
+        }
+    } else if (cur->type == GGML_TYPE_Q8_0) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &q8_0_4x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &q8_0_4x4_q8_0;
+            }
+        }
     }
 
     return nullptr;
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h
index c4d928cd15a..b9f821630c4 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h
@@ -44,6 +44,7 @@ struct block_q4_Kx8 {
 };
 
 static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
 struct block_q2_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
     ggml_half dmin[8];   // super-block scale for quantized mins
@@ -52,6 +53,28 @@ struct block_q2_Kx8 {
 };
 
 static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
+
+struct block_q5_Kx8 {
+    ggml_half d[8];              // super-block scale for quantized scales
+    ggml_half dmin[8];           // super-block scale for quantized mins
+    uint8_t   scales[96];        // scales and mins, quantized with 6 bits
+    uint8_t   qh[QK_K * 8 / 8];  // high bits of 5-bit quants
+    uint8_t   qs[QK_K * 8 / 2];  // low bits of 5-bit quants (in groups of 4)
+};
+
+static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
+              "wrong q5_K block size/padding");
+
+struct block_q6_Kx8 {
+    ggml_half d[8];
+    int8_t    scales[QK_K / 16 * 8];
+    uint8_t   ql[QK_K / 2 * 8];  // low bits of 6-bit quants (groups of 2)
+    uint8_t   qh[QK_K / 4 * 8];  // high bits of 6-bit quants (groups of 4)
+};
+
+static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
+              "wrong q6_K block size/padding");
+
 struct block_q8_Kx4 {
     float d[4];              // delta
     int8_t qs[QK_K * 4];     // quants
@@ -74,6 +97,19 @@ struct block_iq4_nlx8 {
 
 static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
 
+struct block_mxfp4x4 {
+    uint8_t e[4];
+    uint8_t qs[QK_MXFP4 * 2];
+};
+static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding");
+
+struct block_mxfp4x8 {
+    uint8_t e[8];
+    uint8_t qs[QK_MXFP4 * 4];
+};
+static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
+
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
@@ -85,19 +121,35 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -107,19 +159,35 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 #if defined(__cplusplus)
 } // extern "C"
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/simd-gemm.h b/ml/backend/ggml/ggml/src/ggml-cpu/simd-gemm.h
new file mode 100644
index 00000000000..78d663e593e
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/simd-gemm.h
@@ -0,0 +1,136 @@
+#pragma once
+
+// Computes C[M x N] += A[M x K] * B[K x N]
+
+#include "simd-mappings.h"
+
+// TODO: add support for sizeless vector types
+#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
+
+// TODO: untested on avx512
+// These are in units of GGML_F32_EPR
+#if defined(__AVX512F__) || defined (__ARM_NEON__)
+    static constexpr int GEMM_RM = 4;
+    static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
+#elif defined(__AVX2__) || defined(__AVX__)
+    static constexpr int GEMM_RM = 6;
+    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
+#else
+    static constexpr int GEMM_RM = 2;
+    static constexpr int GEMM_RN = 2;
+#endif
+
+template <int RM, int RN>
+static inline void simd_gemm_ukernel(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    GGML_F32_VEC acc[RM][RN];
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
+        }
+    }
+
+    for (int64_t kk = 0; kk < K; kk++) {
+        GGML_F32_VEC Bv[RN];
+        for (int r = 0; r < RN; r++) {
+            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
+        }
+        for (int64_t i = 0; i < RM; i++) {
+            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
+            for (int r = 0; r < RN; r++) {
+                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
+            }
+        }
+    }
+
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
+        }
+    }
+}
+
+// C[M x N] += A[M x K] * B[K x N]
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    int64_t ii = 0;
+    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            for (int64_t i = 0; i < GEMM_RM; i++) {
+                float a = C[i * N + jj];
+                for (int64_t kk = 0; kk < K; kk++) {
+                    a += A[i * K + kk] * B[kk * N + jj];
+                }
+                C[i * N + jj] = a;
+            }
+        }
+
+        A += GEMM_RM * K;
+        C += GEMM_RM * N;
+    }
+
+    // Tail rows: one at a time
+    for (; ii < M; ii++) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            float a = C[jj];
+            for (int64_t kk = 0; kk < K; kk++) {
+                a += A[kk] * B[kk * N + jj];
+            }
+            C[jj] = a;
+        }
+
+        A += K;
+        C += N;
+    }
+}
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+
+#else // scalar path
+
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    for (int64_t i = 0; i < M; i++) {
+        for (int64_t j = 0; j < N; j++) {
+            float sum = C[i * N + j];
+            for (int64_t kk = 0; kk < K; kk++) {
+                sum += A[i * K + kk] * B[kk * N + j];
+            }
+            C[i * N + j] = sum;
+        }
+    }
+}
+
+#endif // GGML_SIMD
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h
index 101a9c086b2..22de55700d4 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h
@@ -14,10 +14,6 @@
 #include 
 #endif
 
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
 #if defined(__riscv_v_intrinsic)
 #include <riscv_vector.h>
 #endif
@@ -120,6 +116,17 @@ extern "C" {
 // defined in ggml-cpu.c, initialized in ggml_cpu_init()
 extern float ggml_table_f32_f16[1 << 16];
 
+// precomputed f32 table for e8m0 half (1 KB)
+// defined in ggml-cpu.c, initialized in ggml_cpu_init()
+extern float ggml_table_f32_e8m0_half[1 << 8];
+
+// Use lookup table for E8M0 on x86 (faster than bit manipulation)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
+#else
+#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
@@ -658,6 +665,14 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
           vec_extract(x[0], 2) +               \
           vec_extract(x[0], 3);                \
 }
+#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3)        \
+{                                                       \
+    vector float v = vec_add(vec_add(s0, s1),           \
+                             vec_add(s2, s3));          \
+    v = vec_add(v, vec_sld(v, v, 8));                   \
+    v = vec_add(v, vec_sld(v, v, 4));                   \
+    res += (ggml_float) vec_extract(v, 0);              \
+}
 
 #define GGML_F32_VEC        GGML_F32x4
 #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
@@ -694,6 +709,29 @@ static inline unsigned char ggml_endian_byte(int i) {
                                    r[i - GGML_ENDIAN_BYTE(0)]), \
             0, p - GGML_F16_EPR)
 
+//BF16 POWER9
+#define GGML_BF16_STEP 16
+#define GGML_BF16_EPR  8
+
+#define GGML_BF16x8         vector unsigned short
+#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)
+#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
+
+#define GGML_BF16_VEC          GGML_BF16x8
+#define GGML_BF16_VEC_ZERO     GGML_BF16x8_ZERO
+#define GGML_BF16_VEC_LOAD     GGML_BF16x8_LOAD
+#if defined(__LITTLE_ENDIAN__)
+#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v)))
+#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v)))
+#else
+#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO))
+#endif
+#define GGML_BF16_FMA_LO(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
+#define GGML_BF16_FMA_HI(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
+
 #elif defined(__wasm_simd128__)
 
 #define GGML_SIMD
@@ -1122,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
     float32x4_t tmp = x[0] + vec_reve(x[0]);        \
     res = tmp[0] + tmp[1];                          \
 }
+#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
+{                                                \
+    float32x4_t v = vec_add(vec_add(s0, s1),     \
+                            vec_add(s2, s3));    \
+    v = vec_add(v, vec_sld(v, v, 8));            \
+    v = vec_add(v, vec_sld(v, v, 4));            \
+    res += (ggml_float)vec_extract(v, 0);        \
+}
 
 #define GGML_F32_VEC        GGML_F32x4
 #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
@@ -1171,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
 #define GGML_F16_VEC_MUL            GGML_F32x4_MUL
 #define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
 
+// BF16 s390x
+#define GGML_BF16_STEP 16
+#define GGML_BF16_EPR  8
+
+#define GGML_BF16x8         __vector unsigned short
+#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)
+#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
+
+#define GGML_BF16_VEC      GGML_BF16x8
+#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
+#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
+#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_FMA_LO(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
+#define GGML_BF16_FMA_HI(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
+
 #elif defined(__riscv_v_intrinsic)
 
 // compatible with vlen >= 128
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp
index 1d9873ad0f2..1d8344436f0 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp
@@ -111,7 +111,7 @@ template 
 static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
index ac8633e2128..d0e4001338a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp
@@ -195,6 +195,63 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
     sumf += (ggml_float)_mm_cvtss_f32(g);
 
 #undef LOAD
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
+    size_t vl = __riscv_vsetvlmax_e32m4();
+
+    // initialize accumulators to all zeroes
+    vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+    vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+    // calculate step size
+    const size_t epr = __riscv_vsetvlmax_e16m2();
+    const size_t step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (; i < np; i += step) {
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
+        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
+        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
+        vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
+        vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
+
+    // accumulate in 1 register
+    vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);
+
+    // leftovers
+    for (i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
+        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
+        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
+    }
+
+    // reduce
+    vl = __riscv_vsetvlmax_e32m4();
+    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
+
+#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
+    const int np = (n & ~(GGML_BF16_STEP - 1));
+    if (np > 0) {
+        GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
+        for (; i < np; i += GGML_BF16_STEP) {
+            GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i);
+            GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8);
+            GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i);
+            GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8);
+            GGML_BF16_FMA_LO(sum[0], vx0, vy0);
+            GGML_BF16_FMA_HI(sum[1], vx0, vy0);
+            GGML_BF16_FMA_LO(sum[2], vx1, vy1);
+            GGML_BF16_FMA_HI(sum[3], vx1, vy1);
+        }
+        GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]);
+    }
 #endif
 
     for (; i < n; ++i) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h
index bd80805fdc5..3198b33b509 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h
@@ -224,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
         }
         GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
         GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
-    #elif defined(__riscv_v_intrinsic)
-      // todo: RVV impl
-      for (int i = 0; i < n; ++i) {
-          for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-              sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
-          }
-      }
+
+    #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
+        size_t vl = __riscv_vsetvlmax_e32m4();
+
+        // initialize accumulators to all zeroes
+        vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+        vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+        vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+        vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+        // calculate step size
+        const size_t epr = __riscv_vsetvlmax_e16m2();
+        const size_t step = epr * 2;
+        const int np = (n & ~(step - 1));
+
+        // unroll by 2 along the row dimension
+        for (int i = 0; i < np; i += step) {
+            vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
+            vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
+            vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
+            vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
+            vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
+
+            vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
+            vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
+            vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
+            vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
+            vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
+        }
+
+        vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
+        vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
+
+        // leftovers
+        for (int i = np; i < n; i += vl) {
+            vl = __riscv_vsetvl_e16m2(n - i);
+            vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
+            vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
+            vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
+
+            vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
+            vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
+        }
+
+        // reduce
+        vl = __riscv_vsetvlmax_e32m2();
+        vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
+                                    __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
+        vl = __riscv_vsetvlmax_e32m1();
+        vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
+        __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
+        vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
+                                    acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+
+        vl = __riscv_vsetvlmax_e32m2();
+        vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
+                                    __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
+        vl = __riscv_vsetvlmax_e32m1();
+        vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
+                                    __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
+        vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
+                                    acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+        sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
+        sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
+
     #else
         const int np = (n & ~(GGML_F16_STEP - 1));
 
@@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
     }
     np = n;
 #elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
-    const int np = n;
-    _Float16 hv = (_Float16)v;
-    for (int i = 0, avl; i < n; i += avl) {
-        avl = __riscv_vsetvl_e16m8(n - i);
-        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
-        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
-        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
-        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
+    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
+    const _Float16 scale = *(const _Float16*)(&s);
+
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m4();
+    const int step = epr * 2;
+    int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (int i = 0; i < np; i += step) {
+        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
+        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+        ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
+
+    // leftovers
+    int vl;
+    for (int i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m4(n - i);
+        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
     }
+    np = n;
 #elif defined(GGML_SIMD)
     const int np = (n & ~(GGML_F16_STEP - 1));
 
@@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
         svst1_f16(pg, (__fp16 *)(y + np), out);
     }
 #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
-    for (int i = 0, vl; i < n; i += vl) {
-        vl = __riscv_vsetvl_e16m2(n - i);
-        vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
-        vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
-        vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
-        vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
-        __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
+    const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
+    const _Float16 scale = *(const _Float16*)(&s);
+
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m4();
+    const int step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (int i = 0; i < np; i += step) {
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+        ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
+
+    // leftovers
+    int vl;
+    for (int i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m4(n - i);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
     }
 #elif defined(GGML_SIMD)
     const int np = (n & ~(GGML_F16_STEP - 1));
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
index 67af1d8ccc1..262f88204e0 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
         # 80     == Ampere, asynchronous data loading, faster tensor core instructions
         # 86     == RTX 3000, needs CUDA v11.1
         # 89     == RTX 4000, needs CUDA v11.8
+        # 120    == Blackwell, needs CUDA v12.8, FP4 tensor cores
         #
         # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
         # XX-real    == compile CUDA code as device code for this specific architecture
@@ -34,12 +35,69 @@ if (CUDAToolkit_FOUND)
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
                 list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
             endif()
+
+            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+                # The CUDA architecture 120f-virtual would in principle work for Blackwell support
+                #     but the newly added "f" suffix conflicted with a preexising regex for validating CUDA architectures in CMake.
+                # So either a recent CMake version or one with the backported fix is needed.
+                # The following versions should work:
+                #   - CMake >= v3.31.8 && CMake < v4.0.0
+                #   - CMake >= v4.0.2
+                # This is NOT documented in the CMake release notes,
+                #     check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
+                # However, the architectures 120a-real and 121a-real should work with basically any CMake version and
+                #     until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
+                list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
+            endif()
+            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
+                list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
+            endif()
         endif()
     endif()
-    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
 
     enable_language(CUDA)
 
+    # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+    if (GGML_CUDA_CUB_3DOT2)
+        include(FetchContent)
+
+        FetchContent_Declare(
+            CCCL
+            GIT_REPOSITORY https://github.com/nvidia/cccl.git
+            GIT_TAG        v3.2.0
+            GIT_SHALLOW    TRUE
+        )
+
+        FetchContent_MakeAvailable(CCCL)
+    endif()
+
+    # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
+    # 12X is forwards-compatible, 12Xa is not.
+    # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
+    # But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
+    # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
+    foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
+        set(FIXED_ARCHS "")
+        foreach(ARCH IN LISTS ${ARCHS})
+            if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
+                string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH})
+                message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}")
+                list(APPEND FIXED_ARCHS "${FIXED_ARCH}")
+            else()
+                list(APPEND FIXED_ARCHS "${ARCH}")
+            endif()
+        endforeach()
+        set(${ARCHS} ${FIXED_ARCHS})
+    endforeach()
+
+    # If we try to compile a "native" build it will use the 12X architectures and fail.
+    # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
+    # But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use.
+    if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$")
+        set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
+    endif()
+    message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}")
+
     file(GLOB   GGML_HEADERS_CUDA "*.cuh")
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
@@ -102,6 +160,9 @@ if (CUDAToolkit_FOUND)
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
             target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
         else ()
+            if (GGML_CUDA_CUB_3DOT2)
+                target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)
+            endif()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
                 target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
             else()
@@ -109,6 +170,9 @@ if (CUDAToolkit_FOUND)
             endif()
         endif()
     else()
+        if (GGML_CUDA_CUB_3DOT2)
+            target_link_libraries(ggml-cuda PRIVATE  CCCL::CCCL)
+        endif()
         target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
     endif()
 
@@ -177,6 +241,10 @@ if (CUDAToolkit_FOUND)
 
     if (NOT MSVC)
         list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    else()
+        # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
+        # https://github.com/NVIDIA/cccl/pull/6827
+        list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
     endif()
 
     list(JOIN   CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu
index 5340eedc089..51967c667cf 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
     }
 
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
+    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
         const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
         const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
         if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
                 argmax = shared_argmax[lane_id];
             }
 #pragma unroll
-            for (int offset = 16; offset > 0; offset >>= 1) {
+            for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
                 const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
                 const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
                 if (val > maxval) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu
index b82be371c9e..6fae8b80867 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu
@@ -2,6 +2,9 @@
 
 #ifdef GGML_CUDA_USE_CUB
+#    include <cub/cub.cuh>
+#    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
+#        define STRIDED_ITERATOR_AVAILABLE
+#    endif
 using namespace cub;
 #endif  // GGML_CUDA_USE_CUB
 
@@ -14,63 +17,90 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
     }
 }
 
+#ifndef STRIDED_ITERATOR_AVAILABLE
 static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx <= nrows) {
         offsets[idx] = idx * ncols;
     }
 }
+#endif  // STRIDED_ITERATOR_AVAILABLE
 
 #ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
-                                     const float *    x,
-                                     int *            dst,
-                                     const int        ncols,
-                                     const int        nrows,
-                                     ggml_sort_order  order,
-                                     cudaStream_t     stream) {
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float *    x,
+                              int *            dst,
+                              const int        ncols,
+                              const int        nrows,
+                              ggml_sort_order  order,
+                              cudaStream_t     stream) {
     ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
-    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
 
     int *   temp_indices = temp_indices_alloc.get();
     float * temp_keys    = temp_keys_alloc.get();
-    int *   d_offsets    = offsets_alloc.get();
 
     static const int block_size = 256;
     const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
 
-    const dim3 offset_grid((nrows + block_size - 1) / block_size);
-    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
-
+#ifdef STRIDED_ITERATOR_AVAILABLE
+    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
+#else
+    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
+    int *                     offset_iterator = offsets_alloc.get();
+    const dim3                offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
+#endif
     CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 
     size_t temp_storage_bytes = 0;
 
     if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
-                                            temp_indices, dst,                                  // values (indices)
-                                            ncols * nrows, nrows,                            // num items, num segments
-                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,  // all bits
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                       temp_indices, dst,                                  // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                           temp_indices, dst,                                  // values (indices)
+                                           ncols * nrows, nrows,  // num items, num segments
+                                           offset_iterator, offset_iterator + 1, stream);
+        }
     } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
-                                                      sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                                 temp_indices, dst,                                  // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+                                                     stream);
+        }
     }
 
     ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
     void *                        d_temp_storage = temp_storage_alloc.get();
 
     if (order == GGML_SORT_ORDER_ASC) {
-        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
-                                            stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                       temp_indices, dst,  // values (indices)
+                                       ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
+        }
     } else {
-        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                      0, sizeof(float) * 8, stream);
+        if (nrows == 1) {
+            DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                                 temp_indices, dst,                                  // values (indices)
+                                                 ncols, 0, sizeof(float) * 8, stream);
+        } else {
+            DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+                                                     offset_iterator + 1, stream);
+        }
     }
 }
 #endif  // GGML_CUDA_USE_CUB
@@ -141,12 +171,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda_bitonic(const float *   x,
-                                         int *           dst,
-                                         const int       ncols,
-                                         const int       nrows,
-                                         ggml_sort_order order,
-                                         cudaStream_t    stream) {
+void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                  int *           dst,
+                                  const int       ncols,
+                                  const int       nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t    stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh
index 68a001547ff..22b7306f202 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh
@@ -1,3 +1,19 @@
 #include "common.cuh"
 
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float *    x,
+                              int *            dst,
+                              const int        ncols,
+                              const int        nrows,
+                              ggml_sort_order  order,
+                              cudaStream_t     stream);
+#endif  // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float *   x,
+                                  int *           dst,
+                                  const int       ncols,
+                                  const int       nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t    stream);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
index 0e6d777b1e6..7339fe0c070 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
                                    const uint3            ne11,
                                    const uint3            ne12,
                                    const uint3            ne13,
-                                   /*int s0, */ const int s1,
+                                 /*const int              s0,*/
+                                   const int              s1,
                                    const int              s2,
                                    const int              s3,
-                                   /*int s00,*/ const int s01,
+                                   const int              s00,
+                                   const int              s01,
                                    const int              s02,
                                    const int              s03,
-                                   /*int s10,*/ const int s11,
+                                   const int              s10,
+                                   const int              s11,
                                    const int              s12,
                                    const int              s13,
                                    src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
         const uint32_t i10 = fastmodulo(i0, ne10);
 
-        float result = src0_row ? (float) src0_row[i0] : 0.0f;
+        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
         if constexpr (sizeof...(src1_ptrs) > 0) {
-            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
         } else {
-            result = bin_op(result, (float)src1[i_src1 + i10]);
+            result = bin_op(result, (float)src1[i_src1 + i10*s10]);
         }
 
         dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
                                            const uint3            ne11,
                                            const uint3            ne12,
                                            const uint3            ne13,
-                                           /*int s0, */ const int s1,
+                                         /*const int              s0,*/
+                                           const int              s1,
                                            const int              s2,
                                            const int              s3,
-                                           /*int s00,*/ const int s01,
+                                           const int              s00,
+                                           const int              s01,
                                            const int              s02,
                                            const int              s03,
-                                           /*int s10,*/ const int s11,
+                                           const int              s10,
+                                           const int              s11,
                                            const int              s12,
                                            const int              s13,
                                            src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
 
     const int i10 = fastmodulo(i0, ne10);
 
-    float result = src0_row ? (float) src0_row[i0] : 0.0f;
+    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
     if constexpr (sizeof...(src1_ptrs) > 0) {
-        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
     } else {
-        result = bin_op(result, (float)src1[i_src1 + i10]);
+        result = bin_op(result, (float)src1[i_src1 + i10*s10]);
     }
 
     dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         cnb[3] *= cne[3];
     };
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
         for (int i = 0; i < 4; i++) {
             if (nr[i] != 1) {
                 break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         size_t nb12 = cnb1[2];
         size_t nb13 = cnb1[3];
 
-        size_t s0 = nb0 / sizeof(dst_t);
+      //size_t s0 = nb0 / sizeof(dst_t);
         size_t s1 = nb1 / sizeof(dst_t);
         size_t s2 = nb2 / sizeof(dst_t);
         size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
         GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
 
-        GGML_ASSERT(s0 == 1);
-        GGML_ASSERT(s00 == 1);
-        GGML_ASSERT(s10 == 1);
-
         const int block_size = 128;
 
         int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
                 k_bin_bcast_unravel<<>>(
                     src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                     ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast_unravel
                     <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                            ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
-                                                           /* s0, */ s1, s2, s3,
-                                                           /* s00,*/ s01, s02, s03,
-                                                           /* s10,*/ s11, s12, s13);
+                                                         /*s0,*/ s1,  s2,  s3,
+                                                           s00, s01, s02, s03,
+                                                           s10, s11, s12, s13);
             }
         } else {
             const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
             if constexpr (sizeof...(I) > 0) {
                 k_bin_bcast<<>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1, s2,  s3,
+                    s00 ,s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast<<>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13);
             }
         }
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
index e800ee8f613..321357713c5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
@@ -85,6 +85,11 @@ static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t coun
 #define GGML_CUDA_CC_TURING          750
 #define GGML_CUDA_CC_AMPERE          800
 #define GGML_CUDA_CC_ADA_LOVELACE    890
+// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to the 1200 family, see
+// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
+#define GGML_CUDA_CC_BLACKWELL       1200
+#define GGML_CUDA_CC_DGX_SPARK       1210
+#define GGML_CUDA_CC_RUBIN           1300
 #define GGML_CUDA_CC_OFFSET_AMD      0x1000000
 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
 #define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -281,6 +286,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMPERE_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+#    define BLACKWELL_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -289,6 +298,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define FLASH_ATTN_AVAILABLE
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
+#if defined(TURING_MMA_AVAILABLE)
+#define LDMATRIX_TRANS_AVAILABLE
+#endif // defined(TURING_MMA_AVAILABLE)
+
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
@@ -351,6 +364,11 @@ static bool cp_async_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
 
+static bool blackwell_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
+           ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
     return 64;
@@ -548,6 +566,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+enum class block_reduce_method {
+    MAX,
+    SUM,
+};
+
+template
+struct block_reduce_policy;
+
+template 
+inline constexpr bool is_any = (std::is_same_v || ...);
+
+template
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template  struct block_reduce_policy {
+    static __device__ T reduce(T val) {
+        if constexpr(is_any) {
+            return warp_reduce_sum(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v) {
+            return 0.0f;
+        } else if constexpr (std::is_same_v) {
+            return make_float2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v) {
+            return make_half2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v) {
+            return 0;
+        } else {
+            static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum");
+        }
+    }
+};
+
+template  struct block_reduce_policy {
+    static __device__ T reduce(T val) {
+        if constexpr (is_any) {
+            return warp_reduce_max(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v) {
+            return -INFINITY;
+        } else if constexpr (std::is_same_v) {
+            return make_half2(-INFINITY, -INFINITY);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max");
+        }
+    }
+};
+
+template 
+static __device__ T block_reduce(T val, T * shared_vals) {
+    val                           = block_reduce_policy::reduce(val);
+    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+    if (block_size > WARP_SIZE) {
+        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = block_reduce_policy::sentinel();
+        if (lane_id < (static_cast(block_size) / WARP_SIZE)) {
+            val = shared_vals[lane_id];
+        }
+        return block_reduce_policy::reduce(val);
+    }
+
+    return val;
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
@@ -736,6 +834,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
+    const uint8_t sign_bit = (x < 0.0f) << 3;
+    float         ax       = fabsf(x) * e;
+
+    // Positive LUT
+    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
+
+    int   best_i   = 0;
+    float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
+    for (int i = 1; i < 8; ++i) {
+        const float err = fabsf(ax - pos_lut[i]);
+        if (err < best_err) {
+            best_err = err;
+            best_i   = i;
+        }
+    }
+
+    return static_cast(best_i | sign_bit);
+}
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
 // Precompute mp (m' in the paper) and L such that division
 // can be computed using a multiply (high 32b of 64b result)
@@ -950,15 +1070,16 @@ struct ggml_cuda_device_info {
     int device_count;
 
     struct cuda_device_info {
-        int     cc;                 // compute capability
-        int     nsm;                // number of streaming multiprocessors
-        size_t  smpb;               // max. shared memory per block
-        size_t  smpbo;              // max. shared memory per block (with opt-in)
-        bool    integrated;         // Device is integrated as opposed to discrete
-        bool    vmm;                // virtual memory support
-        size_t  vmm_granularity;    // granularity of virtual memory
+        int     cc;                             // compute capability
+        int     nsm;                            // number of streaming multiprocessors
+        size_t  smpb;                           // max. shared memory per block
+        size_t  smpbo;                          // max. shared memory per block (with opt-in)
+        bool    integrated;                     // Device is integrated as opposed to discrete
+        bool    vmm;                            // virtual memory support
+        size_t  vmm_granularity;                // granularity of virtual memory
         size_t  total_vram;
-        int     warp_size;          // Number of threads in a dispatch
+        int     warp_size;                      // Number of threads in a dispatch
+        bool    supports_cooperative_launch;    // whether cooperative launch is supported
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -1038,15 +1159,19 @@ struct ggml_tensor_extra_gpu {
 #define USE_CUDA_GRAPH
 #endif
 
-struct ggml_graph_node_properties {
-    void * node_address;
+struct ggml_cuda_graph_node_properties {
+    void * node_data;
     ggml_op node_op;
+    enum ggml_type node_type;
+    int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
-    void * src_address[GGML_MAX_SRC];
+    void * src_data[GGML_MAX_SRC];
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
+static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial");
+
 struct ggml_cuda_graph {
 #ifdef USE_CUDA_GRAPH
     ~ggml_cuda_graph() {
@@ -1061,12 +1186,20 @@ struct ggml_cuda_graph {
     cudaGraphExec_t instance = nullptr;
     size_t num_nodes = 0;
     std::vector nodes;
-    std::vector params;
     bool disable_due_to_gpu_arch = false;
-    bool disable_due_to_too_many_updates = false;
-    bool disable_due_to_failed_graph_capture = false;
-    int number_consecutive_updates = 0;
-    std::vector ggml_graph_properties;
+    bool warmup_complete = false;
+    std::vector props;
+
+    // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
+    // their properties also have to match in order to be able to safely reuse a CUDA graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18583
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19165
+    std::vector extra;
+
+    bool is_enabled() const {
+        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
+    }
 #endif
 };
 
@@ -1229,10 +1362,44 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    std::unique_ptr cuda_graph;
-
     int curr_stream_no = 0;
 
+#ifdef USE_CUDA_GRAPH
+    // Map from first_node_ptr to cuda_graph - allows multiple graphs per context
+    // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
+    std::unordered_map> cuda_graphs;
+
+    ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
+        auto it = cuda_graphs.find(first_node_ptr);
+        if (it == cuda_graphs.end()) {
+            cuda_graphs[first_node_ptr] = std::make_unique();
+            return cuda_graphs[first_node_ptr].get();
+        }
+        return it->second.get();
+    }
+
+    // Check if any CUDA graph is enabled for this context (used by kernels that need to know
+    // if graphs are in use without having access to the specific graph key)
+    bool any_cuda_graph_enabled() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->is_enabled()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Check if any CUDA graph has an instance for this context
+    bool any_cuda_graph_has_instance() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->instance != nullptr) {
+                return true;
+            }
+        }
+        return false;
+    }
+#endif // USE_CUDA_GRAPH
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
index ba3d4eeb880..b70492c7d6c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
@@ -7,7 +7,8 @@
 
 template 
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
@@ -15,24 +16,28 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
 
-    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+            const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
 
-    const int64_t ib = ibx0 + i00/qk; // block index
-    const int64_t iqs = (i00%qk)/qr; // quant index
-    const int64_t iybs = i00 - i00%qk; // y block start index
-    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+            const int64_t ib = ibx0 + i00/qk; // block index
+            const int64_t iqs = (i00%qk)/qr; // quant index
+            const int64_t iybs = i00 - i00%qk; // y block start index
+            const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
-    // dequantize
-    float2 v;
-    dequantize_kernel(vx, ib, iqs, v);
+            // dequantize
+            float2 v;
+            dequantize_kernel(vx, ib, iqs, v);
 
-    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = ggml_cuda_cast(v.x);
-    y[iy0 + y_offset] = ggml_cuda_cast(v.y);
+            const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
+            y[iy0 + 0]        = ggml_cuda_cast(v.x);
+            y[iy0 + y_offset] = ggml_cuda_cast(v.y);
+        }
+    }
 }
 
 template 
@@ -485,9 +490,11 @@ template 
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     dequantize_block<<>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template 
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
 
 template 
 static __global__ void convert_unary(
-        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -620,24 +628,30 @@ static __global__ void convert_unary(
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
-
     const src_t * x = (const src_t *) vx;
 
-    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
-    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = ggml_cuda_cast(x[ix]);
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
+
+            const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+            const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
+            y[iy] = ggml_cuda_cast(x[ix]);
+        }
+    }
 }
 
 template 
 static void convert_unary_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     convert_unary<<>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
index 0e53ecc39b2..178e82d7613 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
@@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available
 const int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows
 
 template 
-static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
-                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                  const int nb12, const int nb13) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
+                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                  const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                  const int64_t nb12, const int64_t nb13) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
@@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
 }
 
 template 
-static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13) {
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
+                               const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                               const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                               const int64_t nb12, const int64_t nb13) {
 
     const T* src = reinterpret_cast(cx);
     T* dst = reinterpret_cast(cdst);
@@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 }
 
 template 
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
+                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                 const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
 template 
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
+                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                 const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
 template
 static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
@@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda(
 cudaStream_t stream) {
 
     const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_scalar_contiguous<<>>
         (cx, cdst, ne);
 }
 
 template
 static void ggml_cpy_scalar_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     if (transposed) {
         GGML_ASSERT(ne == ne00*ne01*ne02);  // ne[3] is 1 assumed
-        int ne00n, ne01n, ne02n;
+        int64_t ne00n, ne01n, ne02n;
         if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
             ne00n = ne00;
             ne01n = ne01;
@@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda(
             ne02n = 1;
         }
 
-        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
+        GGML_ASSERT(grid_x < UINT_MAX);
+        GGML_ASSERT(grid_y < USHRT_MAX);
+        GGML_ASSERT(grid_z < USHRT_MAX);
+        dim3 dimGrid(grid_x, grid_y, grid_z);
         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
         cpy_scalar_transpose<<>>
             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     } else {
-        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        GGML_ASSERT(num_blocks < UINT_MAX);
         cpy_scalar><<>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     }
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
+    const int64_t num_blocks = ne / QK8_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
+    const int64_t num_blocks = ne / QK4_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32, QK4_0><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
          ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
+    const int64_t num_blocks = ne / QK4_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32, QK4_1><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
          ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
-    const int num_blocks = ne / QK5_0;
+    const int64_t num_blocks = ne / QK5_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32, QK5_0><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
-    const int num_blocks = ne / QK5_1;
+    const int64_t num_blocks = ne / QK5_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32, QK5_1><<>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
-    const int num_blocks = ne / QK4_NL;
+    const int64_t num_blocks = ne / QK4_NL;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<<>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
@@ -393,9 +410,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu
index d2f2def8bdc..def9c32955f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu
@@ -5,7 +5,7 @@
 #include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
-#   include 
+#   include 
 #endif // GGML_CUDA_USE_CUB
 
 template
@@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
         const int64_t  s01, const int64_t  s02, const int64_t  s03,
         const int64_t   s1,  const int64_t   s2,  const int64_t   s3) {
 #ifdef GGML_CUDA_USE_CUB
-    using BlockScan = cub::BlockScan;
+    using BlockScanT = cub::BlockScan;
 
-    __shared__ typename BlockScan::TempStorage temp_storage;
-    __shared__ T block_carry;      // carry from previous tile
+    __shared__ typename BlockScanT::TempStorage temp_storage;
+    __shared__ T block_carry;
 
     const int tid = threadIdx.x;
+    constexpr int UNROLL_FACTOR = 4;
+    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
 
     const int64_t i1 = blockIdx.x;
     const int64_t i2 = blockIdx.y;
@@ -39,37 +41,47 @@ static __global__ void cumsum_cub_kernel(
     }
     __syncthreads();
 
-    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
-        int64_t idx = start + tid;
-        T x = (idx < ne00) ? src_row[idx] : T(0);
+    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+        T items[UNROLL_FACTOR];
+        T thread_sum = T(0);
 
-        T inclusive;
-        T block_total;
-        BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
+#pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            T val = (idx < ne00) ? src_row[idx] : T(0);
+            thread_sum += val;
+            items[i] = thread_sum;
+        }
 
+        // Block-wide scan on thread sums
+        T thread_prefix;
+        T block_total;
+        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
         __syncthreads();
 
-        T final_val = inclusive + block_carry;
-
-        // store result
-        if (idx < ne00) {
-            dst_row[idx] = final_val;
+        // Add offset to each item and store
+        T thread_offset = thread_prefix - thread_sum + block_carry;
+#pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            if (idx < ne00) {
+                dst_row[idx] = items[i] + thread_offset;
+            }
         }
 
         __syncthreads();
 
+        // Update carry for next tile
         if (tid == 0) {
             block_carry += block_total;
         }
-
-        __syncthreads();
     }
 #else
     NO_DEVICE_CODE;
 #endif // GGML_CUDA_USE_CUB
 }
 
-// Fallback kernel implementation (original)
+// Fallback kernel implementation
 template
 static __global__ void cumsum_kernel(
         const T * src, T * dst,
@@ -86,10 +98,10 @@ static __global__ void cumsum_kernel(
     const int warps_per_block = blockDim.x / warp_size;
 
     extern __shared__ float smem[];
-    float * s_vals = smem;
-    float * s_warp_sums = smem + blockDim.x;
-    float * s_carry = smem + blockDim.x + warps_per_block;
-    float * s_chunk_total = s_carry + 1;
+    float *                 s_vals        = smem;
+    float *                 s_warp_sums   = smem + blockDim.x;
+    float *                 s_carry       = smem + blockDim.x + warps_per_block;
+    float *                 s_chunk_total = s_carry + 1;
 
     // Initialize carry
     if (tid == 0) {
@@ -107,21 +119,39 @@ static __global__ void cumsum_kernel(
     const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
     T       * dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;
 
-    for (int64_t start = 0; start < ne00; start += blockDim.x) {
-        int64_t idx = start + tid;
-        float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f;
+    // register blocking: process 4 elements per thread to hide latency
+    // and reduce synchronization overhead
+    constexpr int num_unroll = 4;
+    T             temp[num_unroll];
+
+    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+        int64_t idx = i + tid * num_unroll;
+
+        // thread local sequential scan
+        temp[0] = (idx < ne00 ? src_row[idx] : T(0));
+#pragma unroll
+        for (int64_t j = 1; j < num_unroll; j++) {
+            temp[j] = temp[j - 1];
+            if (idx + j < ne00) {
+                temp[j] += src_row[idx + j];
+            } else {
+                temp[j] += 0;
+            }
+        }
 
-        // 1. Warp inclusive scan
+        // last element is sum of all values assigned to thread
+        float val = (idx < ne00) ? ggml_cuda_cast(temp[num_unroll - 1]) : 0.0f;
+
+        // Warp inclusive scan
         val = warp_prefix_inclusive_sum(val);
         s_vals[tid] = val;
 
-        // Store warp total
         if (lane == warp_size - 1) {
             s_warp_sums[warp] = val;
         }
         __syncthreads();
 
-        // 2. Exclusive scan of warp sums (warp 0 only)
+        // Exclusive scan of warp sums (warp 0 only)
         if (warp == 0) {
             float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
             float inc = warp_prefix_inclusive_sum(w);
@@ -134,24 +164,55 @@ static __global__ void cumsum_kernel(
         }
         __syncthreads();
 
+        // write back results
         float carry = *s_carry;
-        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
-        if (idx < ne00) {
-            dst_row[idx] = ggml_cuda_cast(final_val);
+        // calculate sum offset for this thread
+        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
+
+#pragma unroll
+        for (int32_t j = 0; j < num_unroll; j++) {
+            if (idx + j < ne00) {
+                dst_row[idx + j] = temp[j] + ggml_cuda_cast(final_val_offset);
+            }
         }
+
         __syncthreads();
 
         // Update carry for next chunk
         if (tid == 0) {
             *s_carry += *s_chunk_total;
         }
-        __syncthreads();
     }
 }
 
+#ifdef GGML_CUDA_USE_CUB
+template 
+static void cumsum_cub(ggml_cuda_pool & pool,
+                       const T *        src,
+                       T *              dst,
+                       int64_t          ne,
+                       cudaStream_t     stream) {
+    size_t tmp_size = 0;
+
+    // Query how much temp storage CUDA UnBound (CUB) needs
+    cub::DeviceScan::InclusiveSum(nullptr,   // d_temp_storage (null = just query size)
+                                  tmp_size,  // reference to size (will be set by CUB)
+                                  src,       // input pointer
+                                  dst,       // output pointer
+                                  ne,        // number of elements
+                                  stream     // CUDA stream to use
+    );
+
+    ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size);
+
+    // Perform the inclusive scan
+    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+#endif // GGML_CUDA_USE_CUB
+
 template
 static void cumsum_cuda(
-        const T * src, T * dst,
+        [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
         const int64_t  nb0,  const int64_t nb1, const int64_t  nb2, const int64_t  nb3,
@@ -165,6 +226,15 @@ static void cumsum_cuda(
 
     if (is_contiguous) {
         use_cub = true;
+        const int64_t nrows = ne01 * ne02 * ne03;
+        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
+        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+            for (int i=0; i= 1024) {
         cumsum_cub_kernel<<>>(
             src, dst,
             ne00, ne01, ne02, ne03,
@@ -203,7 +273,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         case GGML_TYPE_F32:
             {
                 cumsum_cuda(
-                    (const float *)src0->data, (float *)dst->data,
+                    ctx, (const float *)src0->data, (float *)dst->data,
                     src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                     src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                     dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
index 8dc82a9d3b8..b6a7460da83 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
@@ -11,10 +11,12 @@
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
 // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
-//     by the VKQ accumulators is effectively being shifted up by a factor of 8.
+//     by the VKQ accumulators is effectively being shifted up by a factor of 2.
 // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
 // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
-#define FATTN_KQ_MAX_OFFSET 0.6931f
+// Still, the value range should be shifted as much as necessary but as little as possible.
+// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
 
 typedef void (* fattn_kernel_t)(
         const char * __restrict__ Q,
@@ -57,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
 
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
-        half2 tmp[cpy_ne];
+        __align__(16) half2 tmp[cpy_ne];
         ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -307,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
         ggml_cuda_memcpy_1(dst, (const half *) vx + i0);
     } else if constexpr (std::is_same_v) {
         static_assert(ne % 2 == 0, "bad ne");
-        half2 tmp[ne/2];
+        __align__(16) half2 tmp[ne/2];
         ggml_cuda_memcpy_1(tmp, (const half *) vx + i0);
         float2 * dst_f2 = (float2 *) dst;
 #pragma unroll
@@ -627,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
-        const int nbatch_fa) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
+        const int ne11, const int ne12, const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -639,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
+
+    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -652,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    if (jt*ncols1 + j >= ne01) {
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
         return;
     }
 
-    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -679,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -776,13 +785,11 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    GGML_ASSERT(V || is_mla);
+    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
 
     const ggml_tensor * mask  = dst->src[3];
     const ggml_tensor * sinks = dst->src[4];
@@ -792,9 +799,9 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
-    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
-    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
 
@@ -815,10 +822,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = V ? (const char *) V->data : nullptr;
-    size_t nb21 = V ? V->nb[1] : nb11;
-    size_t nb22 = V ? V->nb[2] : nb12;
-    size_t nb23 = V ? V->nb[3] : nb13;
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         const size_t bs = ggml_blck_size(K->type);
@@ -847,36 +854,45 @@ void launch_fattn(
         K_data = (char *) K_f16.ptr;
     }
 
-    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        V_f16.alloc(ggml_nelements(V));
-        if (ggml_is_contiguously_allocated(V)) {
-            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-            V_data = (char *) V_f16.ptr;
-
-            nb21 = nb21*bs*sizeof(half)/ts;
-            nb22 = nb22*bs*sizeof(half)/ts;
-            nb23 = nb23*bs*sizeof(half)/ts;
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        if (V_is_K_view) {
+            V_data = K_data;
+            nb21   = nb11;
+            nb22   = nb12;
+            nb23   = nb13;
         } else {
-            GGML_ASSERT(V->nb[0] == ts);
-            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
-            const int64_t s01 = nb21 / ts;
-            const int64_t s02 = nb22 / ts;
-            const int64_t s03 = nb23 / ts;
-            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
-
-            nb21 = V->ne[0] * sizeof(half);
-            nb22 = V->ne[1] * nb21;
-            nb23 = V->ne[2] * nb22;
+            const size_t bs = ggml_blck_size(V->type);
+            const size_t ts = ggml_type_size(V->type);
+
+            V_f16.alloc(ggml_nelements(V));
+            if (ggml_is_contiguously_allocated(V)) {
+                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+                V_data = (char *) V_f16.ptr;
+
+                nb21 = nb21*bs*sizeof(half)/ts;
+                nb22 = nb22*bs*sizeof(half)/ts;
+                nb23 = nb23*bs*sizeof(half)/ts;
+            } else {
+                GGML_ASSERT(V->nb[0] == ts);
+                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+                const int64_t s01 = nb21 / ts;
+                const int64_t s02 = nb22 / ts;
+                const int64_t s03 = nb23 / ts;
+                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+                nb21 = V->ne[0] * sizeof(half);
+                nb22 = V->ne[1] * nb21;
+                nb23 = V->ne[2] * nb22;
+            }
+            V_data = (char *) V_f16.ptr;
         }
-        V_data = (char *) V_f16.ptr;
     }
 
-    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
-    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
+    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int gqa_ratio    = Q->ne[2] / K->ne[2];
+    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -912,13 +928,15 @@ void launch_fattn(
 
         const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+        }
     } else {
         const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
 
@@ -949,7 +967,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
+        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -1003,7 +1021,7 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup
                 <<>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 3dea2205e55..beb7e32e4fc 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,8 +66,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);
 
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32, 288, 256, 128, 1, false);
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
@@ -81,8 +80,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
 
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32,  96,  64, 128, 1, false);
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
@@ -91,8 +89,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
 }
 
 static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  4,  64, 4,  32, 288, 256,  64, 1, false);
-    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128,  64, 1, false);
@@ -101,6 +98,57 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
+
+    // TODO tune specifically for RDNA
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
+    // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  32,  32,  32, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 256, 2,  64,  40,  40,  40, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 256, 2,  64,  48,  48,  48, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2,  64,  56,  56,  56, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2,  64,  64,  64,  64, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2,  32, 128, 128, 128, 1, true);
+
+    // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
+    // compile-time static_asserts even though the kernel guard prevents runtime execution.
+    // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
+    return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
+}
+
 static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -108,6 +156,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
     if (turing_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    if (amd_mfma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
+    }
+    if (amd_wmma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+    }
     GGML_ASSERT(volta_mma_available(cc));
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }
@@ -117,8 +171,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 #elif defined(TURING_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(AMD_MFMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
 #elif defined(VOLTA_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
     return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -189,6 +247,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
     return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }
 
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    return 1; // AMD has a single column per thread.
+#else
+    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+}
+
+static __host__ int get_cols_per_warp(const int cc) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
+        return 16;
+    } else {
+        // Volta
+        return 32;
+    }
+}
+
 // ------------------------------------------------------------------------------------------------------------------
 
 static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -209,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
 template
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
     if constexpr (use_cp_async) {
@@ -220,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
         auto load = [&] __device__ (auto n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
             const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -231,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -239,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                 }
@@ -255,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     } else {
         // TODO use ggml_cuda_memcpy_1
         auto load = [&] __device__ (const int n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
             const int k0_stop  =                             D2 - D2 % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -266,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -274,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                 }
@@ -292,18 +368,19 @@ template= 32 ? nbatch_fa * sizeof(half) : 64;
-        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+        constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
@@ -325,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
             }
         }
-    } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
-        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+    } else if constexpr (nbatch_fa < 2*warp_size) {
+        constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
-            const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
+            const int i = threadIdx.x % (warp_size/cols_per_warp);
 
             ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
         }
@@ -358,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
                 const int i = i0 + 2*threadIdx.x;
 
                 ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
@@ -368,7 +445,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
 }
 
 template
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
@@ -396,10 +473,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
+    constexpr int  warp_size       = ggml_cuda_get_physical_warp_size();
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
     constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
@@ -410,19 +488,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
 
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #else // Volta
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
 
     if constexpr (nstages > 1) {
         static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
-        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -437,8 +516,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
+    // For MLA K and V have the same data.
+    // Therefore, iterate over K in reverse and later re-use the data if possible.
 #pragma unroll
-    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
@@ -464,8 +545,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     if constexpr (cols_per_warp == 8) {
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
-                        // Wide version of KQ_C is column-major => swap A and B.
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+                        // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
@@ -485,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
                     } else {
                         // Wide version of KQ_C is column-major
-#if defined(AMD_WMMA_AVAILABLE)
-                        // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
 #else
                         // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
@@ -543,8 +630,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -554,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = 16; offset >= 4; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -564,8 +657,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
-                    KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                 } else {
                     KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
                 }
@@ -595,9 +694,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
+                if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
                     // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -608,14 +712,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             // Values per KQ column are spread across 4 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 1;
-#else
+#elif defined(AMD_MFMA_AVAILABLE)
+            // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
+            constexpr int offset_first = 32;
+            constexpr int offset_last  = 16;
+#elif defined(AMD_WMMA_AVAILABLE)
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 16;
+            constexpr int offset_last  = 16;
+#else // Volta
             // Values per KQ column are spread across 2 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 2;
 #endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -624,10 +736,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                // Turing + Volta:
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
-                    KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                 } else {
                     KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
@@ -651,7 +768,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -672,6 +789,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(
+            KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const half2 KQ_max_scale_h2 = make_half2(
             KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -700,6 +827,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
     if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -715,19 +843,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 
-    // For MLA K and V have the same data.
-    // Therefore, iterate over V in reverse and re-use the data if possible.
-    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+    T_A_VKQ A_identity;
+    make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
-    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
-        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
-        const int i0_diff  = i0_stop - i0_start;
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+        const int i0_stop = i0_start + 2*nbatch_V2;
+        const int i0_diff = i0_stop - i0_start;
 
         if constexpr (nstages <= 1) {
-            if (i0_start < reusable_cutoff) {
+            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
                 constexpr bool use_cp_async = nstages == 1;
                 flash_attn_ext_f16_load_tile
                     (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -737,9 +866,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 __syncthreads();
             }
         }
-        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
 
-#if defined(TURING_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
         for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -749,12 +878,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
                 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#elif defined(AMD_MFMA_AVAILABLE)
+                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
+                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
+                // Load with transposed addressing: 4 strided half loads.
+                {
+                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
+                    const half * xs0_h = (const half *) xs0;
+                    const int stride_h = stride_tile_V * 2; // stride in half units
+                    half * A_h = (half *) A.x;
+#pragma unroll
+                    for (int l = 0; l < 4; ++l) {
+                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
+                    }
+                }
+#else
+                // TODO: Try to transpose tile_V when loading gmem to smem.
+                // Use mma to transpose T_A_VKQ for RDNA.
+                T_A_VKQ A_trans;
+                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
                 if constexpr (T_B_KQ::I == 8) {
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
-                    // Wide version of VKQ_C is column-major => swap A and B.
+                    // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    // AMD matrix C is column-major.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+                    // swap A and B for CUDA.
                     mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                 }
             }
         }
@@ -773,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
             }
         }
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 
         if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
@@ -786,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -806,6 +963,15 @@ template<> struct mma_tile_sizes<8> {
     using T_B_VKQ = tile< 8,  8, half2>; // column-major
     using T_C_VKQ = tile<16,  4, half2>; // row-major
 };
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+template struct mma_tile_sizes {
+    using T_A_KQ  = tile<16,  8, half2>; // row-major
+    using T_B_KQ  = tile<16,  8, half2>; // column-major
+    using T_C_KQ  = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16,  8, half2>; // row-major
+    using T_B_VKQ = tile<16,  8, half2>; // column-major
+    using T_C_VKQ = tile<16,  8, half2>; // column-major
+};
 #else // Volta
 template struct mma_tile_sizes {
     using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -817,7 +983,7 @@ template struct mma_tile_sizes {
 };
 #endif // defined(TURING_MMA_AVAILABLE)
 
-template
+template
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -831,6 +997,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float logit_softcap,
         const uint3 ne01,
         const int ne02,
+        const int gqa_ratio,
         const int ne11,
         const int stride_Q1,
         const int stride_Q2,
@@ -838,11 +1005,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int stride_V,
         const int stride_mask,
         const int jt,
+        const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols = ncols1 * ncols2;
     using     T_A_KQ    = typename mma_tile_sizes::T_A_KQ;
     using     T_B_KQ    = typename mma_tile_sizes::T_B_KQ;
@@ -852,7 +1021,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     using     T_C_VKQ   = typename mma_tile_sizes::T_C_VKQ;
 
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
     constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
@@ -871,8 +1040,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
     extern __shared__ half2 tile_Q[];
@@ -883,6 +1051,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
 #if defined(TURING_MMA_AVAILABLE)
     T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #else // Volta
     T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -899,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
-    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start  = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+        const int k0_start  = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
         const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);
-        const int stride_jc = WARP_SIZE / stride_k;
+        const int stride_jc = warp_size / stride_k;
 
         if (k0_start == k0_stop) {
             continue;
@@ -910,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
         for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
-            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
             if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
                 break;
@@ -919,10 +1089,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if (jt*ncols1 + j < int(ne01.z)) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
                     tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -930,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                 }
@@ -974,7 +1144,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -983,7 +1153,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
         flash_attn_ext_f16_iter
-            
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -994,7 +1164,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1003,7 +1173,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         constexpr int  k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
-            
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1022,6 +1192,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // The partial sums are spread across 8/4 threads.
         constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
         constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+        // The partial sums are spread across 4 threads (wavefront64, 16 cols).
+        constexpr int offset_first = 32;
+        constexpr int offset_last  = 16;
+#elif defined(AMD_WMMA_AVAILABLE)
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 16;
+        constexpr int offset_last  = 16;
 #else // Volta
         // The partial sums are spread across 2 threads.
         constexpr int offset_first = 2;
@@ -1031,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
             }
         }
     }
@@ -1059,7 +1237,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -1080,6 +1258,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const int col = (threadIdx.x / 2) % 2;
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1131,6 +1318,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
         const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
         const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
 #else // Volta
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
         const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1161,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.
 
-        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+        constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
 
-        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
         float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
         float2 meta[nmeta];
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+            meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
         }
 
         float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -1178,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            if (offset < warp_size) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
             }
         }
 
@@ -1196,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            if (offset < warp_size) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
             }
         }
 
@@ -1206,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Write back combined meta data:
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+            if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
                 // Combined KQ max scale + rowsum.
-                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+                meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
             }
         }
 
         // Combined KQ max + rowsum.
-        static_assert(cols_per_warp <= WARP_SIZE);
-        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        static_assert(cols_per_warp <= warp_size);
+        if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
@@ -1266,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
 
 #pragma unroll
-            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-                const int k0_start  = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+            for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+                const int k0_start  = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
                 const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);
-                const int stride_jc = WARP_SIZE / stride_k;
+                const int stride_jc = warp_size / stride_k;
 
                 if (k0_start == k0_stop) {
                     continue;
@@ -1277,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
                 for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
-                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                     if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
                         break;
@@ -1288,14 +1479,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int j_dst = jc_dst / ncols2;
                     const int c_dst = jc_dst % ncols2;
 
-                    if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
+                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
                         continue;
                     }
 
                     const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
 #pragma unroll
                     for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                        const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                         float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
@@ -1327,14 +1518,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #else
     GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
-        scale, slope, logit_softcap, ne01, ne02,
+        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
-template
+template
 __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -1358,7 +1549,7 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1379,12 +1570,25 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 
-    static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
+#if defined(AMD_WMMA_AVAILABLE)
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_WMMA_AVAILABLE)
 
+#if defined(AMD_MFMA_AVAILABLE)
+    if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_MFMA_AVAILABLE)
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols     = ncols1 * ncols2;
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
-    constexpr int nwarps    = nthreads / WARP_SIZE;
+    constexpr int nwarps    = nthreads / warp_size;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
@@ -1393,14 +1597,15 @@ static __global__ void flash_attn_ext_f16(
     const int stride_K    = nb11 / sizeof(half2);
     const int stride_mask = nb31 / sizeof(half);
 
-    const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1411,22 +1616,24 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+        const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+        const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-        const int head0 = zt * ncols2;
+        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
             (const half *) (mask + nb33*(sequence % ne33));
-        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
         if (KV_max) {
             kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1434,14 +1641,14 @@ static __global__ void flash_attn_ext_f16(
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile
+            flash_attn_ext_f16_process_tile
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
-            flash_attn_ext_f16_process_tile
+            flash_attn_ext_f16_process_tile
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1455,22 +1662,24 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+    const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-    const int head0 = zt * ncols2;
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
     const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
         (const half *) (mask + nb33*(sequence % ne33));
-    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
     if (KV_max) {
         kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1478,9 +1687,9 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile
+    flash_attn_ext_f16_process_tile
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
         max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1492,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 }
 
 template 
@@ -1511,10 +1720,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
     const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
-    const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
-    const int nwarps        = nthreads / WARP_SIZE;
+    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
+    const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
+    const int nwarps         = nthreads / warp_size_host;
 
-    constexpr bool mla = DKQ == 576;
+    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
 
     const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
     const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
@@ -1531,33 +1741,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
+#if defined(GGML_USE_HIP)
+    using fattn_kernel_ptr_t = const void*;
+#else
+    using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16;
+        fattn_kernel = flash_attn_ext_f16;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16;
+        fattn_kernel = flash_attn_ext_f16;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     }
 
     launch_fattn
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
 }
 
 
@@ -1609,3 +1824,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  2, 32);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
index 371be74421c..f3fa80ab23d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -351,7 +351,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                 for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                     const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
 
-                    const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                     ggml_cuda_memcpy_1(
                         tile_KV + i*(J/2 + J_padding) + j,
                         !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -402,11 +402,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                     const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
 
                     const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
-                    half2 tmp_h2[cpy_ne/2];
+                    __align__(16) half2 tmp_h2[cpy_ne/2];
                     ggml_cuda_memcpy_1(
                         tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
 
-                    float2 tmp_f2[cpy_ne/2];
+                    __align__(16) float2 tmp_f2[cpy_ne/2];
 #pragma unroll
                     for (int l = 0; l < cpy_ne/2; ++l) {
                         tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -453,14 +453,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
     static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
-        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        half2 Q_k[cpw][cpy_ne];
+        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) half2 Q_k[cpw][cpy_ne];
 #else
     static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
-        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        float Q_k[cpw][cpy_ne];
+        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -610,9 +610,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
 #ifdef FAST_FP16_AVAILABLE
-        half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #else
-        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -672,8 +672,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            half2 V_k[(DVp/2)/warp_size];
-            half2 KQ_k[cpw];
+            __align__(16) half2 V_k[(DVp/2)/warp_size];
+            __align__(16) half2 KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
@@ -684,7 +684,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
             for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                 const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
 
-                half tmp[KQ_cs];
+                __align__(16) half tmp[KQ_cs];
                 ggml_cuda_memcpy_1(
                     &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
 #pragma unroll
@@ -704,8 +704,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #else
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            float2 V_k[(DVp/2)/warp_size];
-            float  KQ_k[cpw];
+            __align__(16) float2 V_k[(DVp/2)/warp_size];
+            __align__(16) float  KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
@@ -829,12 +829,12 @@ static __global__ void flash_attn_tile(
     __shared__ half2 Q_tmp[ncols * DKQ/2];
     __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
     __shared__ half  KQ[ncols * nbatch_fa];
-    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #else
     __shared__ float Q_tmp[ncols * DKQ];
     __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
     __shared__ float KQ[ncols * nbatch_fa];
-    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #endif // FAST_FP16_AVAILABLE
 
     float KQ_max[cpw];
@@ -857,7 +857,7 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
-                float tmp_f[cpy_ne_D] = {0.0f};
+                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                 ggml_cuda_memcpy_1
                     (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                                  + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -868,7 +868,7 @@ static __global__ void flash_attn_tile(
                 }
 
 #ifdef FAST_FP16_AVAILABLE
-                half2 tmp_h2[cpy_ne_D/2];
+                __align__(16) half2 tmp_h2[cpy_ne_D/2];
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -967,7 +967,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-                half2 tmp[cpy_ne_D];
+                __align__(16) half2 tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -978,7 +978,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
-                float tmp[cpy_ne_D];
+                __align__(16) float tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1041,7 +1041,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
+            __align__(16) float2 tmp[cpy_ne_D];
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
+    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
+    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
     const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
-    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
+    const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
     const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
     if constexpr (DV == 512) {
@@ -1195,10 +1197,6 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             launch_fattn_tile_switch_ncols1(ctx, dst);
             return;
         }
-        if (use_gqa_opt && gqa_ratio % 8 == 0) {
-            launch_fattn_tile_switch_ncols1(ctx, dst);
-            return;
-        }
         if (use_gqa_opt && gqa_ratio % 4 == 0) {
             launch_fattn_tile_switch_ncols1(ctx, dst);
             return;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh
index 4d167b95a07..3f4a78cc6e5 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
     return 128;
 }
 
-// Currenlty llvm with the amdgcn target dose not support unrolling loops
+// Currently llvm with the amdgcn target does not support unrolling loops
 // that contain a break that can not be resolved at compile time.
 #ifdef __clang__
 #pragma clang diagnostic push
@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
 #ifdef V_DOT2_F32_F16_AVAILABLE
     half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
-    float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
 #endif // V_DOT2_F32_F16_AVAILABLE
     int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2  Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
             for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
-                float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+                __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
                 if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1(tmp,            &Q_j[i]);
                     ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 8694fd06c7b..f19defbff93 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
+    typedef wmma::fragment frag_a_K;
+    typedef wmma::fragment frag_a_V;
+    typedef wmma::fragment frag_b;
+    typedef wmma::fragment                      frag_c_KQ;
+    typedef wmma::fragment                          frag_c_VKQ;
+#else
     typedef wmma::fragment frag_a_K;
     typedef wmma::fragment frag_a_V;
     typedef wmma::fragment frag_b;
     typedef wmma::fragment                      frag_c_KQ;
     typedef wmma::fragment                          frag_c_VKQ;
+#endif
 
     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
+
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
+    const _Float16 * K_h_f16  = reinterpret_cast(K_h);
+    const _Float16 * V_h_f16  = reinterpret_cast(V_h);
+    _Float16       * KQ_f16   = reinterpret_cast<_Float16 *>(KQ);
+    _Float16       * VKQ_f16  = reinterpret_cast<_Float16 *>(VKQ);
+#else
+    const half * K_h_f16  = K_h;
+    const half * V_h_f16  = V_h;
+    half       * KQ_f16   = KQ;
+    half       * VKQ_f16  = VKQ;
+#endif
+
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
+                    KQ_f16 + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
             }
         }
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                 wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                     D_padded, wmma::mem_col_major);
             }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
index 1693479cb54..85c177f496f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
@@ -18,12 +18,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         }
     }
 
-    if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
-        return;
+    if constexpr (ncols2 <= 16) {
+        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
+            ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+            return;
+        }
     }
 
-    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
+    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
 
 template 
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -46,7 +49,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     //     are put into the template specialization without GQA optimizations.
     bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
     for (const ggml_tensor * t : {Q, K, V, mask}) {
-        if (t == nullptr) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
             continue;
         }
         for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (use_gqa_opt && gqa_ratio % 8 == 0) {
+    // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio > 4) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 4 == 0) {
+    if (use_gqa_opt && gqa_ratio > 2) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 2 == 0) {
+    if (use_gqa_opt && gqa_ratio > 1) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
 }
 
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -111,7 +136,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
             break;
         case 576: {
-            // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
             GGML_ASSERT(V->ne[0] == 512);
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +146,46 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 
             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
             const int gqa_ratio = Q->ne[2] / K->ne[2];
-            GGML_ASSERT(gqa_ratio % 4 == 0);
-            if (gqa_ratio % 16 == 0) {
+            if (gqa_ratio == 20) { // GLM 4.7 Flash
+                if (cc >= GGML_CUDA_CC_DGX_SPARK) {
+                    if (Q->ne[1] <= 8) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_BLACKWELL) {
+                    if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 4) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    if (Q->ne[1] <= 4) {
+                        if (K->ne[1] <= 16384) {
+                            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                            break;
+                        }
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                // Volta:
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+            } else if (gqa_ratio % 16 == 0) {
                 ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
@@ -234,7 +297,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // The effective batch size for the kernel can be increased by gqa_ratio.
     // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
-    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                gqa_opt_applies = false;
+                break;
+            }
+        }
+    }
 
     const int cc = ggml_cuda_info().devices[device].cc;
 
@@ -255,7 +329,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
+            if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -341,6 +415,43 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
+    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        if (can_use_vector_kernel) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+                if (Q->ne[1] == 1) {
+                    if (!gqa_opt_applies) {
+                        return BEST_FATTN_KERNEL_VEC;
+                    }
+                }
+            } else {
+                if (Q->ne[1] <= 2) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        }
+        int gqa_ratio_eff = 1;
+        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 8) {
+            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+        }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
+
+    // Use MFMA flash attention for CDNA (MI100+):
+    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
+        const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
+        // MMA vs tile crossover benchmarked on MI300X @ d32768:
+        //   hsk=64  (gqa=4): MMA wins at eff >= 128 (+11%)
+        //   hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
+        if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
+            return BEST_FATTN_KERNEL_MMA_F16;
+        }
+        // Fall through to tile kernel for small effective batch sizes.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
index 5c9dfd03242..e42a1599be6 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -19,6 +19,7 @@
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/diag.cuh"
 #include "ggml-cuda/fattn.cuh"
@@ -44,6 +45,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/topk-moe.cuh"
@@ -68,20 +70,23 @@
 #include 
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
-#include 
-#include 
+#include 
+#include 
+#include 
 #include 
 #include 
+#include 
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
+bool reserving_graph = false;
+
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
@@ -251,16 +256,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#ifdef GGML_CUDA_FORCE_MMQ
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
-#endif // GGML_CUDA_FORCE_MMQ
-#ifdef GGML_CUDA_FORCE_CUBLAS
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
-#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
 
     std::vector> turing_devices_without_mma;
@@ -301,6 +296,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm        = prop.multiProcessorCount;
         info.devices[id].smpb       = prop.sharedMemPerBlock;
         info.devices[id].warp_size  = prop.warpSize;
+
+#ifndef GGML_USE_MUSA
+        int supports_coop_launch = 0;
+        CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
+        info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
+#else
+        info.devices[id].supports_cooperative_launch = false;
+#endif // !(GGML_USE_MUSA)
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
@@ -407,6 +410,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         allocate(alloc) {
     }
 
+    bool alloc_memory() override { return allocate; }
+    size_t alloc_size() override { return pool_size + last_alloc; }
+
     ~ggml_cuda_pool_leg() {
         ggml_cuda_set_device(device);
         for (int i = 0; i < MAX_BUFFERS; ++i) {
@@ -496,14 +502,6 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         }
         pool_size -= size;
     }
-
-    bool alloc_memory() override {
-        return allocate;
-    }
-
-    size_t alloc_size() override {
-        return pool_size + last_alloc;
-    }
 };
 
 // pool with virtual memory
@@ -531,6 +529,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         }
     }
 
+    bool alloc_memory() override { return allocate; }
+    size_t alloc_size() override { return pool_size + last_alloc; }
+
     ~ggml_cuda_pool_vmm() {
         if (pool_addr != 0 && allocate) {
 #if defined(GGML_USE_HIP)
@@ -634,21 +635,12 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         // all deallocations must be in reverse order of the allocations
         GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
     }
-
-    bool alloc_memory() override {
-        return allocate;
-    }
-
-    size_t alloc_size() override {
-        return pool_size + last_alloc;
-    }
-
 };
 #endif // defined(GGML_USE_VMM)
 
 std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int                  device,
                                                                                [[maybe_unused]] int stream_no,
-                                                                               bool alloc) {
+                                                                               bool                 alloc) {
 #if defined(GGML_USE_VMM)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc));
@@ -2345,7 +2337,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc            = ggml_cuda_info().devices[id].cc;
             const int warp_size     = ggml_cuda_info().devices[id].warp_size;
-            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
             use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
             use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
@@ -2353,7 +2345,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
-        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
         use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
         use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
@@ -2411,17 +2403,24 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
+    // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
     if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        if (ne2 == 1) {
+        static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
+        if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
             if (ggml_is_quantized(src0->type)) {
-                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
+                    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                    return;
+                }
             } else {
-                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                if (GGML_CUDA_CC_IS_AMD(cc)) {
+                    ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                    return;
+                }
             }
-            return;
         }
 
-        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
             ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
             return;
         }
@@ -2432,6 +2431,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         }
     }
 
+    // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
+    // TODO: add asserts to verify this. should work with CUDA, HIP, etc.
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(nb12 % nb11 == 0);
@@ -2821,6 +2822,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM:
             ggml_cuda_op_sum(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
@@ -2833,6 +2837,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SSM_SCAN:
             ggml_cuda_op_ssm_scan(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_cuda_op_top_k(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -2842,9 +2849,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_CUMSUM:
-            ggml_cuda_op_cumsum(ctx, dst);
-            break;
         case GGML_OP_TRI:
             ggml_cuda_op_tri(ctx, dst);
             break;
@@ -2984,19 +2988,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
-    int batch_size, bool use_cuda_graph) {
+static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 
+    bool use_cuda_graph = true;
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
 
-    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
-    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
-    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
-    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
-    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
-    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
-    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -3011,43 +3007,17 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
-            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+        // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
+        if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
+            // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
+            // TODO: figure out a way to enable for larger batch sizes, without hurting performance
+            // ref: https://github.com/ggml-org/llama.cpp/pull/18958
+            use_cuda_graph = false;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
         }
 
-        // If we have an explicit batch size hint then we don't need to use the tensor name heuristics
-        if (batch_size >= 0) {
-            if (batch_size > 1) {
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%d]\n", __func__, batch_size);
-#endif
-            }
-        } else {
-            if (node->op == GGML_OP_ADD &&
-                node->src[1] && node->src[1]->ne[1] > 1 &&
-                (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
-                (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
-                strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
-                strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
-                strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
-                strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
-                // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
-                // by means of matching node names. See
-                // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
-                // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
-                // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-        }
-
         if (!use_cuda_graph) {
             break;
         }
@@ -3056,94 +3026,146 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     return use_cuda_graph;
 }
 
-static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    graph_node_properties->node_address = node->data;
-    graph_node_properties->node_op = node->op;
+static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
+    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
+    props->node_data = node->data;
+    props->node_op = node->op;
+    props->node_type = node->type;
+    props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        graph_node_properties->ne[i] = node->ne[i];
-        graph_node_properties->nb[i] = node->nb[i];
+        props->ne[i] = node->ne[i];
+        props->nb[i] = node->nb[i];
     }
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+        if (!node->src[i]) {
+            continue;
+        }
+
+        props->src_data[i] = node->src[i]->data;
     }
-    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
+    memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
-    if (node->data != graph_node_properties->node_address &&
-          node->op != GGML_OP_VIEW) {
+static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
+    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
+        return false;
+    }
+
+    if (node->op != props->node_op) {
         return false;
     }
 
-    if (node->op != graph_node_properties->node_op) {
+    if (node->type != props->node_type) {
         return false;
     }
 
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (node->ne[i] != graph_node_properties->ne[i]) {
+        if (node->ne[i] != props->ne[i]) {
             return false;
         }
-        if (node->nb[i] != graph_node_properties->nb[i]) {
+        if (node->nb[i] != props->nb[i]) {
             return false;
         }
     }
 
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+    if (node->op != GGML_OP_VIEW) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (!node->src[i]) {
+                if (props->src_data[i] != nullptr) {
+                    return false;
+                }
+                continue;
+            }
+
+            if (node->src[i]->data != props->src_data[i]) {
+                return false;
+            }
         }
     }
 
-    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+    if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }
+
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
         return false;
     }
 
     return true;
 }
 
-static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
+    return cgraph->nodes[0];
+}
 
-    bool cuda_graph_update_required = false;
+static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+    bool res = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
-        cuda_graph_update_required = true;
-    }
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
-        cuda_graph_update_required = true;
-        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    if (graph->props.size() != (size_t)cgraph->n_nodes) {
+        res = true;
+        graph->props.resize(cgraph->n_nodes);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
+    std::unordered_set seen_node;
+    std::vector srcs_extra;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        bool has_matching_properties = true;
-        if (!cuda_graph_update_required) {
-            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        bool props_match = true;
+
+        seen_node.insert(cgraph->nodes[i]);
+
+        if (!res) {
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
+        }
+        if (!props_match) {
+            res = true;
+        }
+        ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
+
+        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
+            if (src && seen_node.find(src) == seen_node.end()) {
+                srcs_extra.push_back(src);
+            }
+        }
+    }
+
+    if (graph->extra.size() != (size_t) srcs_extra.size()) {
+        res = true;
+        graph->extra.resize(srcs_extra.size());
+    }
+
+    for (size_t i = 0; i < srcs_extra.size(); ++i) {
+        bool props_match = true;
+
+        if (!res) {
+            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
         }
-        if (!has_matching_properties) {
-            cuda_graph_update_required = true;
+
+        if (!props_match) {
+            res = true;
         }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
     }
 
-    return cuda_graph_update_required;
+    return res;
 }
 
-static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
 #else
     cudaGraphNode_t errorNode;
     cudaGraphExecUpdateResult result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
 #endif // CUDART_VERSION >= 12000
 
     if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3154,14 +3176,14 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
         // The pre-existing graph exec cannot be updated due to violated constraints
         // so instead clear error and re-instantiate
         (void)cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
+        graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
     } else {
         GGML_ASSERT(stat == cudaSuccess);
     }
 }
-#endif
+#endif // USE_CUDA_GRAPH
 
 static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
                                                 const ggml_tensor * view,
@@ -3197,53 +3219,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
     return true;
 }
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, std::initializer_list unary_ops) {
-#ifndef NDEBUG
-    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
-    GGML_ASSERT(unary_ops.size() == num_unary);
-#endif
-
-    //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list topk_moe_ops =
-        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
-    std::initializer_list topk_moe_ops_with_norm =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
-    std::initializer_list topk_moe_ops_delayed_softmax =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
+static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
+    args.sigmoid         = false;
+    args.softmax         = false;
+    args.delayed_softmax = false;
+    args.prob_bias       = false;
+    args.norm            = false;
 
-    const auto is_equal = [](const std::initializer_list & list1,
-                             const std::initializer_list & list2) {
-        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
-    };
+    const int      n_nodes = cgraph->n_nodes;
+    ggml_tensor ** nodes   = cgraph->nodes;
 
-    if (is_equal(topk_moe_ops_with_norm, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
+        args.softmax = true;
+    }
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
-            return true;
+    if (nodes[node_idx]->op == GGML_OP_UNARY) {
+        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
+            return false;
         }
+        args.sigmoid = true;
     }
 
-    if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
-            return true;
+    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
+        args.delayed_softmax = true;
+    }
+
+    node_idx++;
+
+    if (args.sigmoid || args.softmax) {
+        // SOFTMAX -> RESHAPE
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx];
+        node_idx++;
+
+        if (node_idx >= n_nodes) {
+            return false;
+        }
+
+        // src of bias add is the unreshaped probs (-2 instead of -1)
+        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
+            args.prob_bias = true;
+            node_idx++;
+        }
+        // RESHAPE/ADD -> ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
+            return false;
+        }
+
+        if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
+            return false;
+        }
+
+        node_idx++;
+
+        // ARGSORT-> VIEW
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
+            return false;
+        }
+
+        // GET_ROWS
+        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+    } else if (args.delayed_softmax) {
+        if (node_idx - 2 < 0) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx - 2];
+
+        // VIEW->ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        // GET_ROWS
+        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+                nodes[node_idx]->src[0] != probs_reshaped) {
+            return false;
         }
+        node_idx++;
+
+        static const std::vector remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+        for (const ggml_op op : remaining_ops) {
+            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+                return false;
+            }
+            node_idx++;
+        }
+    }
+
+    // At this point we can check for norm + scale. Everything is now at least valid till the norm
+    if (node_idx >= n_nodes) {
+        return true;
     }
 
-    if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
+        //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
+        static const std::vector norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
+
+        args.norm = true;
+        for (const ggml_op op : norm_ops) {
+            if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+                node_idx++;
+            } else {
+                args.norm = false;
+                return true;
+            }
+        }
+
+        // DIV <- CLAMP, RESHAPE
+        if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
+            args.norm = false;
+            return true;
+        }
+        node_idx++;
 
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+        if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            args.norm = false;
             return true;
         }
+
+        node_idx++;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+        args.scale = true;
     }
 
+    return true;
+}
+
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph *                cgraph,
+                               int                                       node_idx,
+                               std::initializer_list       ops,
+                               std::initializer_list unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
+    const auto is_equal = [](const std::initializer_list & list1,
+                             const std::initializer_list & list2) {
+        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
+    };
+
     std::initializer_list mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
     std::initializer_list mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
 
@@ -3356,11 +3491,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
+    bool graph_evaluated_or_captured = false;
 
     // flag used to determine whether it is an integrated_gpu
-    const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
+    const bool integrated            = ggml_cuda_info().devices[cuda_ctx->device].integrated;
 
     ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
     bool                         is_concurrent_event_active = false;
@@ -3398,6 +3533,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
                 }
             }
+
             if (should_launch_concurrent_events) {
                 // Restore original node order within each concurrent region to enable fusion within streams
 
@@ -3449,6 +3585,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]);
                     }
                 }
+            } else {
+                stream_ctx.concurrent_events.clear();
             }
 
             for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3495,6 +3633,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     continue;
                 }
 
+                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+                    continue;
+                }
+
                 // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it
                 if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
                     continue;
@@ -3503,35 +3645,75 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 // start of fusion operations
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
+                    ggml_cuda_topk_moe_args args;
+
+                    if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
+                        cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
+                        const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
+
+                        std::vector ops;
+
+                        if (can_fuse) {
+                            const ggml_tensor * logits  = node->src[0];
+                            ggml_tensor *       weights = nullptr;
+                            ggml_tensor *       ids     = nullptr;
+                            const ggml_tensor * bias    = nullptr;
+                            const ggml_tensor * clamp   = nullptr;
+                            const ggml_tensor * scale   = nullptr;
+
+                            if (!args.delayed_softmax) {
+                                ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
+                                int     out_nodes[2];  // nodes which can't be elided
+
+                                if (args.prob_bias) {
+                                    bias = cgraph->nodes[i + 2]->src[1];
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 4;
+                                    ids          = cgraph->nodes[i + 4];
+                                } else {
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                            GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 3;
+                                    ids          = cgraph->nodes[i + 3];
+                                }
 
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 9];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_tensor * clamp            = cgraph->nodes[i + 7];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                              /*delayed softmax*/ false, clamp);
-                        i += 9;
-                        continue;
-                    }
-
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 4];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
-                                              /*delayed softmax*/ false);
-                        i += 4;
-                        continue;
-                    }
+                                if (args.norm) {
+                                    ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                            GGML_OP_DIV, GGML_OP_RESHAPE });
+                                    clamp = cgraph->nodes[i + ops.size() - 3];
+                                }
+                                if (args.scale) {
+                                    ops.insert(ops.end(), { GGML_OP_SCALE });
+                                    scale = cgraph->nodes[i + ops.size() - 1];
+                                }
 
-                    if (ggml_cuda_can_fuse(cgraph, i,
-                                           ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
-                        ggml_tensor * weights = cgraph->nodes[i + 5];
-                        ggml_tensor * ids     = cgraph->nodes[i + 1];
+                                weights      = cgraph->nodes[i + ops.size() - 1];
+                                out_nodes[1] = i + ops.size() - 1;
 
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
-                                              /*delayed_softmax*/ true);
-                        i += 5;
-                        continue;
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                        ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            } else if (!args.norm && !args.prob_bias) {
+                                // Special case for gpt-oss: no norm and no bias.
+                                ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
+                                weights                     = cgraph->nodes[i + 5];
+                                ids                         = cgraph->nodes[i + 1];
+                                const ggml_tensor * softmax = cgraph->nodes[i + 4];
+
+                                int out_nodes[2] = { i + 1, i + 5 };
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                        ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            }
+                        }
                     }
 
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
@@ -3563,11 +3745,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         n_fuse++;
 
                         if (n_fuse > 1) {
+                            ggml_tensor fused_add_node;
+                            memcpy(&fused_add_node, node, sizeof(ggml_tensor));
                             for (int j = 0; j < n_fuse - 1; ++j) {
-                                node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+                                fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
                             }
-                            cgraph->nodes[i + n_fuse - 1]->data = node->data;
-                            ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                            fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
+                            ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
                             i += n_fuse - 1;
 
                             continue;
@@ -3808,13 +3992,14 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
 #ifdef USE_CUDA_GRAPH
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
+            if (graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(graph->graph));
+                graph->graph = nullptr;
             }
 
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
 
             std::lock_guard lock(ggml_cuda_lock);
@@ -3827,75 +4012,85 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     }
 
     if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+        if (graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
         }
         if (cuda_graph_update_required) { // Update graph executable
-            update_cuda_graph_executable(cuda_ctx);
+            ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
         }
         // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+        CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
 #else
+        GGML_UNUSED(graph_key);
         graph_evaluated_or_captured = true;
 #endif  // USE_CUDA_GRAPH
     }
 }
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-    cuda_ctx->pool_set_alloc(true);
-
-    ggml_cuda_set_device(cuda_ctx->device);
-
 #ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
+static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
+    if (graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
+            if (!graph->disable_due_to_gpu_arch) {
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            }
+            graph->disable_due_to_gpu_arch = true;
         }
     }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
+    return graph->is_enabled();
+}
+#endif // USE_CUDA_GRAPH
 
-    if (use_cuda_graph) {
-        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+    cuda_ctx->pool_set_alloc(true);
 
-        use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);
+    GGML_UNUSED(batch_size);
 
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
+    ggml_cuda_set_device(cuda_ctx->device);
 
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
+    bool use_cuda_graph             = false;
+    bool cuda_graph_update_required = false;
+    const void * graph_key = nullptr;
+
+#ifdef USE_CUDA_GRAPH
+    graph_key = ggml_cuda_graph_get_key(cgraph);
+
+    ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+    if (graph->is_enabled()) {
+        const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
+        if (graph_compatible) {
+            const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
+
+            if (!graph->warmup_complete) {
+                // Warmup: need at least 2 calls with no property change on the 2nd call
+                if (!properties_changed) {
+                    graph->warmup_complete = true;
+                    GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
+                    use_cuda_graph = true;
+                    cuda_graph_update_required = true;
+                }
+                // else: properties changed or first call - execute directly (use_cuda_graph stays false)
+            } else {
+                // Post-warmup: normal CUDA graph operation
+                if (properties_changed) {
+                    // Properties changed - reset warmup, execute directly until stable again
+                    graph->warmup_complete = false;
+                    GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
+                } else {
+                    use_cuda_graph = true;
+                    cuda_graph_update_required = graph->instance == nullptr;
+                }
+            }
         }
     }
+#endif // USE_CUDA_GRAPH
 
     if (use_cuda_graph && cuda_graph_update_required) {
         // Start CUDA graph capture
@@ -3907,29 +4102,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
-    bool graph_evaluated_or_captured = false;
-
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
 
     return GGML_STATUS_SUCCESS;
 }
 
-// This is used to skip operations that are not graph safe during the reservation process.
-bool reserving_graph = false;
-
 static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
     cuda_ctx->pool_set_alloc(alloc);
 
+    const void * graph_key = nullptr;
     #ifdef USE_CUDA_GRAPH
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
+    graph_key = ggml_cuda_graph_get_key(cgraph);
+    // cuda_ctx->cuda_graph(graph_key) will auto-create the graph if needed
     #endif
 
     ggml_cuda_set_device(cuda_ctx->device);
@@ -3951,9 +4136,8 @@ static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend,
     try {
         bool use_cuda_graph = false;
         bool cuda_graph_update_required = false;
-        bool graph_evaluated_or_captured = false;
 
-        evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+        ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
     } catch (const std::exception &e) {
         result = GGML_STATUS_FAILED;
     }
@@ -4018,8 +4202,17 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
+#ifdef USE_CUDA_GRAPH
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+#else
+    const bool use_cuda_graph = false;
+    GGML_UNUSED(cuda_ctx);
+    GGML_UNUSED(cgraph);
+#endif
+
     static bool enable_graph_optimization = [] {
-        const char * env = getenv("GGML_CUDA_GRAPH_OPT");
+        const char * env     = getenv("GGML_CUDA_GRAPH_OPT");
         return env != nullptr && atoi(env) == 1;
     }();
 
@@ -4027,12 +4220,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         return;
     }
 
-    GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
-    GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
-
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
     stream_context.reset();
 
+    if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
+        return;
+    }
+
     // number of out-degrees for a particular node
     std::unordered_map fan_out;
     // reverse mapping of node to index in the cgraph
@@ -4093,6 +4287,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
 
+            // only optimize for attn_norm
+            // TODO: make this more generic
+            if (!strstr(root_node->name, "attn_norm")) {
+                continue;
+            }
+
             bool is_part_of_event = false;
             for (const auto & [start, end] : concurrent_node_ranges) {
                 if (root_node_idx >= start && root_node_idx <= end) {
@@ -4337,6 +4537,7 @@ struct ggml_backend_cuda_device_context {
     int driver_major;
     int driver_minor;
     int integrated;
+    int op_offload_min_batch_size;
 };
 
 static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -4496,8 +4697,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
     props->id          = ggml_backend_cuda_device_get_id(dev);
     props->type        = ggml_backend_cuda_device_get_type(dev);
     props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
-
-    // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
+    // Memory reporting is disabled to avoid allocating a CUDA primary context (~300 MB per device).
     // If you need the memory data, call ggml_backend_dev_memory() explicitly.
     props->memory_total = props->memory_free = 0;
 
@@ -4593,6 +4793,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_ROUND:
                 case GGML_UNARY_OP_TRUNC:
+                    // TODO: should become:
+                    //return ggml_is_contiguous_rows(op->src[0]);
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -4812,7 +5014,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            return ggml_is_contiguous(op->src[0]);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -4874,10 +5076,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_ACC:
             return true;
+        case GGML_OP_ACC:
+            // TODO: extend support like so:
+            //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_TOP_K:
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
@@ -4887,8 +5093,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_PAD:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_PAD:
+            return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:
@@ -4938,11 +5145,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
 }
 
 static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return get_op_batch_size(op) >= min_batch_size;
+    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
 
-    GGML_UNUSED(dev);
+    return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
 }
 
 static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@@ -5059,6 +5264,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
         features.push_back({ "FA_ALL_QUANTS", "1" });
     #endif
 
+    {
+        const auto & info = ggml_cuda_info();
+        for (int id = 0; id < info.device_count; ++id) {
+            if (blackwell_mma_available(info.devices[id].cc)) {
+                features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
+                break;
+            }
+        }
+    }
+
     #undef _STRINGIFY
     #undef STRINGIFY
 
@@ -5106,6 +5321,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
         std::lock_guard lock(mutex);
         if (!initialized) {
             ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
+            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
             int driverVersion = 0;
 
             for (int i = 0; i < ggml_cuda_info().device_count; i++) {
@@ -5130,6 +5346,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
                 dev_ctx->driver_major = driverVersion / 1000;
                 dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
                 dev_ctx->integrated = prop.integrated;
+                dev_ctx->op_offload_min_batch_size = min_batch_size;
+
                 ggml_backend_dev_t dev = new ggml_backend_device {
                     /* .iface   = */ ggml_backend_cuda_device_interface,
                     /* .reg     = */ ®,
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
index 347abc18660..49af5389957 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
@@ -31,16 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 #endif // USE_CUDA_GRAPH
     if ((nrows == 1) &&
 #ifdef USE_CUDA_GRAPH
-            // CUDA_GRAPHS_DISABLED
-            ((ncols > 65536) &&
-             ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-              ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
-              ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
-        // CUDA_GRAPHS ENABLED
-        ((ncols > 32768) &&
-         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
-           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+            // Determine if CUDA graphs are effectively disabled for this context
+            // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
+            (((ncols > 65536) &&
+              (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+               ctx.any_cuda_graph_enabled())) ||
+            // CUDA graphs are enabled - use lower threshold
+             ((ncols > 32768) &&
+              !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+                ctx.any_cuda_graph_enabled())))) {
 #else
         (ncols > 65536)) {
 #endif // USE_CUDA_GRAPH
@@ -63,6 +62,9 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int id  = ggml_cuda_get_device();
     const int nsm = ggml_cuda_info().devices[id].nsm;
+
+    // Heuristic for block size selection to optimize occupancy.
+    // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
     if ((nrows / nsm) < 2) {
         const dim3 block_dims(512, 1, 1);
         reduce_rows_f32<<>>(src0_d, dst_d, ncols);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
index dcfa40f4d50..5d1dadd3e4f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
@@ -76,15 +76,29 @@ namespace ggml_cuda_mma {
         // For the A/C matrices this means I major == row major, J major == column major.
         // For the B matrix this means I major == column major, J major == row major.
         // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED  = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED  = 20,
+        DATA_LAYOUT_I_MAJOR           =  0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR           = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED  = 20, // Volta, matrix A&B for RDNA3.
+        DATA_LAYOUT_J_MAJOR_MIRRORED  = 30,
     };
     // Implemented mma combinations are:
     //   - (I_MAJOR, I_MAJOR)          -> I_MAJOR
     //   - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
     //   - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
 
+    static constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
+    }
+
+    static constexpr __device__ data_layout get_input_data_layout() {
+#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+        return DATA_LAYOUT_I_MAJOR_MIRRORED;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    }
+
     template 
     struct tile {};
 
@@ -115,9 +129,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return threadIdx.x % 32;
             } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
+                return threadIdx.x % 16;
             } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+                return threadIdx.x % 32;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -132,9 +146,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return 2 * (threadIdx.x / 32) + l;
             } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                return 4 * (threadIdx.x / 16) + l;
             } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -171,28 +185,19 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
         static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -201,7 +206,23 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+#if defined(RDNA3)
+                if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
+                    // matrix C
+                    return 2 * l + (threadIdx.x / 16);
+                } else {
+                    // matrix A&B
+                    return l;
+                }
+#else
+                // matrix C is the transposed matrix A&B on RDNA4
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -293,12 +314,7 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         half2 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
@@ -317,14 +333,33 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
+                return ne * (threadIdx.x / 16) + l;
+            } else {
                 NO_DEVICE_CODE;
                 return -1;
-#endif // defined(RDNA4)
+            }
+        }
+#elif defined(AMD_MFMA_AVAILABLE)
+        static constexpr int ne = I * J / 64;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -382,42 +417,34 @@ namespace ggml_cuda_mma {
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
-        static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
+        static constexpr int ne = tile::ne;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
-            return false;
+            return tile::supported();
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile::get_i(l);
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile::get_j(l);
+        }
+#elif defined(AMD_MFMA_AVAILABLE)
+        static constexpr int ne = tile::ne;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile::get_j(l);
         }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
@@ -458,11 +485,87 @@ namespace ggml_cuda_mma {
 #endif  // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template 
+    struct tile {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = tile::ne;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            return tile::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile::get_j(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile::get_i(l);
+        }
+    };
+
+    template 
+    struct tile {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+
+        // RDNA3
+        static constexpr int         ne = I * J / 32 * 2;
+
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8)  return true;
+            if (I == 16 && J == 4)  return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int /*l*/) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (supported()) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
     template 
     struct tile {
         static constexpr int         I  = I_;
         static constexpr int         J  = J_;
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+#if defined(RDNA3)
+        static constexpr int         ne = tile::ne;
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile::get_j(l);
+        }
+#else // Volta
         static constexpr int         ne = I * J / (WARP_SIZE/4);
 
         half2 x[ne] = {{0.0f, 0.0f}};
@@ -489,6 +592,29 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#endif // defined(RDNA3)
+    };
+
+    template 
+    struct tile {
+        static constexpr int         I  = I_;
+        static constexpr int         J  = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+        static constexpr int         ne = tile::ne;
+
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile::get_j(l);
+        }
     };
 
     template 
@@ -542,6 +668,21 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
+        return ret;
+    }
+
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        NO_DEVICE_CODE;
+        return tile<8, 8, half2>{};
+    }
 #else // Volta
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -560,6 +701,19 @@ namespace ggml_cuda_mma {
     }
 #endif // defined(TURING_MMA_AVAILABLE)
 
+    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
+#if defined(RDNA4)
+        const int row = t.get_i(0);
+        const int left_right = t.get_j(0) / 4;
+        const int up_down = row / 8;
+        const int idx = row % 8;
+        reinterpret_cast<half *>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
+#else
+        GGML_UNUSED_VARS(t);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+    }
+
     template 
     static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
@@ -569,55 +723,28 @@ namespace ggml_cuda_mma {
                 t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
             }
         } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v || std::is_same_v) {
-#if defined(RDNA4)
-                ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-                ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-                ggml_cuda_memcpy_1(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-                NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
+        // All WMMA layouts have contiguous data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
             } else {
-                NO_DEVICE_CODE;
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
             }
         } else {
-            NO_DEVICE_CODE;
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
         }
 #else
 #pragma unroll
@@ -660,9 +787,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -826,14 +953,54 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        // MFMA: FP16 input, FP32 accumulate, convert back to half2.
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+
+        // Convert existing half2 accumulator to float for MFMA:
+        floatx4_t acc_f32;
+        {
+            const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                acc_f32[i] = (float)acc_h[i];
+            }
+        }
+
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
+
+        // Convert back to half2:
+        {
+            halfx4_t result_h;
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                result_h[i] = (_Float16)acc_f32[i];
+            }
+            reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
+        }
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -847,6 +1014,53 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3)
+        using floatx2_t = __attribute__((ext_vector_type(2))) float;
+        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
+        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,
+                                                            const tile<16, 8, int> & A,
+                                                            const tile<8, 8, int> &  B,
+                                                            uint32_t                 a_scale,
+                                                            uint32_t                 b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        float *     Dxi = (float *) D.x;
+
+        asm volatile(
+            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+            "%10, {0, 0}, %11, {0, 0};"
+            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#else
+        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
+#endif  // BLACKWELL_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef TURING_MMA_AVAILABLE
@@ -887,8 +1101,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -934,14 +1149,22 @@ namespace ggml_cuda_mma {
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // RDNA4
+#elif defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -960,15 +1183,36 @@ namespace ggml_cuda_mma {
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // RDNA4
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3) || defined(CDNA2)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
+        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
+            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
+            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
+#endif // defined(CDNA3) || defined(CDNA2)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_WMMA_AVAILABLE)
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1366,9 @@ namespace ggml_cuda_mma {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
-static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu
index 6643f243b12..aad4c34aa66 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu
@@ -2,6 +2,13 @@
 #include "mmf.cuh"
 #include "mmid.cuh"
 
+static __forceinline__ int mmf_get_rows_per_block(const int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMF_ROWS_PER_BLOCK_CDNA;
+    } else {
+        return MMF_ROWS_PER_BLOCK;
+    }
+}
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         ids_info_ptr = &ids_info;
     }
 
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
+    const int rows_per_block = mmf_get_rows_per_block(cc);
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
             constexpr int vals_per_T = 2;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     }
-    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
+        return false;
+    }
+
+    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
         return false;
     }
 
@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     } else {
         if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
             return false;
+        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+            //TODO: treat CDNA2 as CDNA1, tune the perf when CDNA2 is available.
+            return false;
+        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+            return false;
         } else if (src1_ncols > 16) {
             return false;
         }
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
 
     switch (type) {
         case GGML_TYPE_F32:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc) || amd_wmma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
         default:
             return false;
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh
index e1c695c5c0f..c2a8d54c95a 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh
@@ -7,6 +7,31 @@
 using namespace ggml_cuda_mma;
 
 #define MMF_ROWS_PER_BLOCK 32
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
+static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 512;
+    } else {
+        return 256;
+    }
+}
+
+static __forceinline__ int mmf_get_padding(int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 2;
+    } else {
+        return 4;
+    }
+}
+
+static constexpr __device__ int mmf_get_padding() {
+#if defined(AMD_MFMA_AVAILABLE)
+    return 2;
+#else
+    return 4;
+#endif // defined(AMD_MFMA_AVAILABLE)
+}
 
 struct mmf_ids_data {
     const int32_t * ids_src_compact = nullptr;
@@ -29,21 +54,25 @@ static __global__ void mul_mat_f(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v ? 8 : 16;
-    typedef tile<16,       8, T>     tile_A;
-    typedef tile     tile_B;
-    typedef tile<16,       tile_C_J, float> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
     typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T>     tile_A;
     typedef tile<8,  8, T>     tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -55,7 +84,7 @@ static __global__ void mul_mat_f(
     }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int tile_k_padded = warp_size + mmf_get_padding();
     constexpr int ntA = rows_per_block / tile_A::I;
     constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
 
@@ -196,7 +225,7 @@ static __global__ void mul_mat_f(
     }
 
     float * buf_iw = (float *) compute_base;
-    constexpr int kiw = nwarps*rows_per_block + 4;
+    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
 
     if (nwarps > 1) {
         __syncthreads();
@@ -226,27 +255,34 @@ static __global__ void mul_mat_f(
             return;
         }
 
-        float sum = 0.0f;
-        static_assert(rows_per_block == warp_size, "need loop/check");
+        float sum[rows_per_block/warp_size] = {0.0f};
+        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
 #pragma unroll
         for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
-            const int i = i0 + threadIdx.x;
+#pragma unroll
+            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+                const int i = i0 + i1*warp_size + threadIdx.x;
 
-            sum += buf_iw[j*kiw + i];
+                sum[i1] += buf_iw[j*kiw + i];
+            }
         }
 
         if constexpr (!has_ids) {
-            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+            }
         } else {
             const int slot = (j < cols_per_block) ? slot_map[j] : -1;
             if (slot >= 0 && (col_base + j) < ncols_dst_total) {
-                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+                }
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -254,7 +290,7 @@ static __global__ void mul_mat_f(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 }
 
 //This kernel is for larger batch sizes of mul_mat_id
@@ -269,21 +305,25 @@ static __global__ void mul_mat_f_ids(
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
 // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16,       8, T>     tile_A;
-    typedef tile<tile_B_I, 8, T>     tile_B;
-    typedef tile<16,       tile_C_J, float> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
     typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T>     tile_A;
     typedef tile<8,  8, T>     tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -296,7 +336,7 @@ static __global__ void mul_mat_f_ids(
 
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int tile_k_padded = warp_size + mmf_get_padding();
     constexpr int ntA = rows_per_block / tile_A::I;
     constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
 
@@ -463,7 +503,7 @@ static __global__ void mul_mat_f_ids(
     }
 
     float * buf_iw = (float *) compute_base;
-    constexpr int kiw = nwarps*rows_per_block + 4;
+    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
 
     if (nwarps > 1) {
         __syncthreads();
@@ -493,13 +533,16 @@ static __global__ void mul_mat_f_ids(
             return;
         }
 
-        float sum = 0.0f;
-        static_assert(rows_per_block == warp_size, "need loop/check");
+        float sum[rows_per_block/warp_size] = {0.0f};
+        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
 #pragma unroll
         for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
-            const int i = i0 + threadIdx.x;
+#pragma unroll
+            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+                const int i = i0 + i1*warp_size + threadIdx.x;
 
-            sum += buf_iw[j*kiw + i];
+                sum[i1] += buf_iw[j * kiw + i];
+            }
         }
 
         const int global_j = col_base + j;
@@ -509,23 +552,24 @@ static __global__ void mul_mat_f_ids(
             const int token = (int) qrm.x;
             if (token < ncols_dst_total) {
                 const int slot = (int) qrm.y;
-                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                    dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+                }
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 }
 
-template <typename T, int cols_per_block, int nwarps>
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@@ -549,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
         const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
         const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);
 
-        mul_mat_f_ids<T, cols_per_block, nwarps><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+        mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@@ -560,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
         dim3 block_nums_ids = block_nums;
         block_nums_ids.y *= col_tiles;
 
-        mul_mat_f<T, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
-        mul_mat_f<T, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+        mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     }
 }
 
-template <typename T, int cols_per_block>
+template <typename T, int rows_per_block, int cols_per_block>
 void mul_mat_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -601,7 +645,7 @@ void mul_mat_f_cuda(
 
     int64_t nwarps_best     = 1;
     int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
-    int64_t max_block_size  = 256;
+    int64_t max_block_size  = mmf_get_max_block_size(cc);
     for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
         const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
         if (niter < niter_best) {
@@ -610,10 +654,9 @@ void mul_mat_f_cuda(
         }
     }
 
-    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
+    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
@@ -624,56 +667,56 @@ void mul_mat_f_cuda(
 
     switch (nwarps_best) {
         case 1: {
-            mul_mat_f_switch_ids<T, cols_per_block, 1>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 2: {
-            mul_mat_f_switch_ids<T, cols_per_block, 2>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 3: {
-            mul_mat_f_switch_ids<T, cols_per_block, 3>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 4: {
-            mul_mat_f_switch_ids<T, cols_per_block, 4>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 5: {
-            mul_mat_f_switch_ids<T, cols_per_block, 5>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 6: {
-            mul_mat_f_switch_ids<T, cols_per_block, 6>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 7: {
-            mul_mat_f_switch_ids<T, cols_per_block, 7>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 8: {
-            mul_mat_f_switch_ids<T, cols_per_block, 8>(
+            mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@@ -687,7 +730,7 @@ void mul_mat_f_cuda(
     GGML_UNUSED_VARS(nchannels_y);
 }
 
-template <typename T>
+template <typename T, int rows_per_block>
 static void mul_mat_f_switch_cols_per_block(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -704,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
 
     switch (ncols_case) {
         case  1: {
-            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  2: {
-            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  3: {
-            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  4: {
-            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  5: {
-            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream, ids_data);
         } break;
         case  6: {
-            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  7: {
-            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  8: {
-            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  9: {
-            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 10: {
-            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 11: {
-            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 12: {
-            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 13: {
-            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 14: {
-            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 15: {
-            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 16: {
-            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
@@ -789,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
     }
 }
 
-#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
-    template void mul_mat_f_cuda<T, ncols_dst>( \
+template <typename T>
+static void mul_mat_f_switch_rows_per_block(
+        const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int stride_row_id,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream, const mmf_ids_data * ids_data) {
+    switch (rows_per_block) {
+        case MMF_ROWS_PER_BLOCK: {
+            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
+                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+        } break;
+        case MMF_ROWS_PER_BLOCK_CDNA: {
+            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
+                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+        } break;
+        default:
+            GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
+    }
+}
+
+#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
+    template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
         const T * x, const float * y, const int32_t * ids, float * dst, \
         const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
         const int64_t stride_col_id, const int64_t stride_row_id, \
@@ -799,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
         cudaStream_t stream, const mmf_ids_data * ids_data);
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
 #define DECL_MMF_CASE_EXTERN(ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
 
 #define DECL_MMF_CASE(ncols_dst) \
-    DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
 
 DECL_MMF_CASE_EXTERN(1);
 DECL_MMF_CASE_EXTERN(2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
index f7a2cbca90f..9a69f41d159 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "mmq.cuh"
 #include "quantize.cuh"
 #include "mmid.cuh"
@@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q(
     const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
                             || GGML_CUDA_CC_IS_CDNA(cc);
 
+    // TODO: tighter pool buffer size vs q8 path
+    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
             get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q(
             const int64_t s11 = src1->nb[1] / ts_src1;
             const int64_t s12 = src1->nb[2] / ts_src1;
             const int64_t s13 = src1->nb[3] / ts_src1;
-            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
-                ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+            if (use_native_mxfp4) {
+                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
+                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                        ne11, ne12, ne13, stream);
+
+            } else {
+                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                       ne11, ne12, ne13, stream);
+            }
             CUDA_CHECK(cudaGetLastError());
         }
 
-        const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+        // Stride depends on quantization format
+        const int64_t s12 = use_native_mxfp4 ?
+                                ne11 * ne10_padded * sizeof(block_fp4_mmq) /
+                                    (8 * QK_MXFP4 * sizeof(int))  // block_fp4_mmq holds 256 values (8 blocks of 32)
+                                :
+                                ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
 
         const mmq_args args = {
@@ -174,13 +190,20 @@ void ggml_cuda_mul_mat_q(
     {
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
-        const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
-            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        const int64_t s13 = src1->nb[3] / ts_src1;
+
+        if (use_native_mxfp4) {
+            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        } else {
+            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        }
         CUDA_CHECK(cudaGetLastError());
     }
 
-    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+                                           ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
     const int64_t s13 = ne12*s12;
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
@@ -236,7 +259,7 @@ void ggml_cuda_op_mul_mat_q(
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
 }
 
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
 #ifdef GGML_CUDA_FORCE_CUBLAS
     return false;
 #endif // GGML_CUDA_FORCE_CUBLAS
@@ -297,7 +320,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         if (GGML_CUDA_CC_IS_CDNA3(cc)) {
             return true;
         }
-        if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+        if (n_experts > 64 || ne11 <= 128) {
+            return true;
+        }
+        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
             return true;
         }
         if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
@@ -307,6 +333,31 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     }
 
     if (amd_wmma_available(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+            // High expert counts are almost always better on MMQ due to
+            //     the synchronization overhead in the cuBLAS/hipBLAS path:
+            // https://github.com/ggml-org/llama.cpp/pull/18202
+            if (n_experts >= 64) {
+                return true;
+            }
+
+            // For some quantization types MMQ can have lower peak TOPS than hipBLAS
+            //     so it's only faster for sufficiently small batch sizes:
+            switch (type) {
+                case GGML_TYPE_Q2_K:
+                    return ne11 <= 128;
+                case GGML_TYPE_Q6_K:
+                    return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
+                case GGML_TYPE_IQ2_XS:
+                case GGML_TYPE_IQ2_S:
+                    return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
+                default:
+                    return true;
+            }
+        }
+
+        // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
+        // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
         return true;
     }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
index 1298f99fff6..255e59f6fc6 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4    512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4];       // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4 * 32];  // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq)  == sizeof(block_q8_1_mmq),    "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8                                       + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        // Tile sizes are the same for Q8_1 and FP4 on Blackwell
         case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 }
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
-#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
+#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@@ -761,6 +782,50 @@ template  static __device__ __forceinline__ void loa
     }
 }
 
+template 
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+                                                            int * __restrict__ x_tile,
+                                                            const int kbx0,
+                                                            const int i_max,
+                                                            const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    int *      x_qs = (int *) x_tile;
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+    const int txi = threadIdx.x;
+
+    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+    constexpr int threads_per_row = iter_k / QK_MXFP4;  // each thread processes 1 block
+    constexpr int rows_per_warp   = warp_size / threads_per_row;
+    const int     kbx             = txi % threads_per_row;
+    const int     row_in_warp     = txi / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+        const int k0 = kbx * 4;
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
+            uint32_t e = bxi->e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+        }
+    }
+}
+
 template 
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -797,9 +862,10 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -930,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
+template 
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+                                                               const int * __restrict__ y,
+                                                               float * __restrict__ sum,
+                                                               const int k00) {
+    typedef tile<16, 8, int>   tile_A;
+    typedef tile<8, 8, int>    tile_B;
+    typedef tile<16, 8, float> tile_C;  // Output is float for native scaled MMA
+
+    constexpr int granularity   = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx           = rows_per_warp / tile_C::I;  // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+    // Match layout from load_tiles_mxfp4_fp4
+    const int *      x_qs = (const int *) x;
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int *      y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
+
+    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
+    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+    // Block scale
+    // Each thread has to point to a 4 byte scale value
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            const int k0 = k00 + k01;
+
+            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
+                          MMQ_MMA_TILE_X_K_FP4);
+
+            // Per the block-scaling document, 2 threads in each quad need to supply the scale value
+            const int tidx         = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+            scaleA[n][k01 / (2 * QI_MXFP4)] =
+                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            tile_B   B;
+            uint32_t scaleB;  // 2xN scales
+
+            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+
+                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+                }
+            }
+        }
+    }
+}
+
 template 
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -966,9 +1104,10 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1130,10 +1269,11 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1179,9 +1319,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1435,10 +1576,11 @@ template 
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1501,10 +1643,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2265,10 +2407,11 @@ template 
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16,  8, int> tile_A;
-    typedef tile<16,  8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64,  2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  8, int, input_layout>        tile_A;
+    typedef tile<16,  8, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64,  2, int, input_layout>        tile_load;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2316,9 +2459,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16,  4, int> tile_A;
-    typedef tile<16,  4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16,  4, int, input_layout>        tile_A;
+    typedef tile<16,  4, int, input_layout>        tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2571,14 +2715,14 @@ template  static __device__ __forceinline__ void loa
 
 #pragma unroll
         for (int l = 0; l < QR2_XXS; ++l) {
-            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
-            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+            const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
+            const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
 
-            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2589,12 +2733,12 @@ template  static __device__ __forceinline__ void loa
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
-        const int ls = aux32 >> 28;
+        const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
         const float d = bxi->d;
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #else
-        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 }
@@ -2632,11 +2776,14 @@ template  static __device__ __forceinline__ void loa
 
     #pragma unroll
         for (int l = 0; l < QR2_XS; ++l) {
-            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
-            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+            const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
+            const uint32_t signs = unpack_ksigns(q2[l] >> 9);
 
-            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2760,11 +2907,13 @@ template  static __device__ __forceinline__ void loa
 #pragma unroll
         for (int l = 0; l < QR3_XXS; ++l) {
             const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+            const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
 
-            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
@@ -3015,7 +3164,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
-    typedef tile tile_C;
+    typedef tile tile_C;
     constexpr int rows_per_warp = granularity;
 #else
     typedef tile<16, 8, int> tile_C;
@@ -3102,8 +3251,13 @@ struct mmq_type_traits {
 template 
 struct mmq_type_traits {
     static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+    static constexpr load_tiles_mmq_t load_tiles  = load_tiles_mxfp4_fp4;
+    static constexpr vec_dot_mmq_t    vec_dot_mma = vec_dot_mxfp4_mxfp4_mma;
+#else
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4;
     static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+#endif // BLACKWELL_MMA_AVAILABLE
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
@@ -3236,17 +3390,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    // FP4 tile stores 8 blocks
+    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+    constexpr int ne_block = 4 * QK8_1;
+#endif  // defined(BLACKWELL_MMA_AVAILABLE)
+
+    constexpr int ITER_K          = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3260,9 +3423,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3394,8 +3557,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const     int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int     blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3456,7 +3621,7 @@ static __global__ void mul_mat_q(
             __syncthreads();
         }
 
-        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
         offset_dst += it*mmq_y;
 
         const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
@@ -3523,7 +3688,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
@@ -3537,16 +3702,25 @@ static __global__ void mul_mat_q(
          tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
-
 template 
-static __global__ void mul_mat_q_stream_k_fixup(
-        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
-        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
-        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
-        const int ncols_max) {
+static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
+                                                const int32_t * expert_bounds,
+                                                float * __restrict__ dst,
+                                                const float * __restrict__ tmp_last_tile,
+                                                const int    ncols_x,
+                                                const int    nrows_x,
+                                                const int    ncols_dst,
+                                                const size_t stride_col_dst,
+                                                const int    nchannels_y,
+                                                const size_t stride_channel_dst,
+                                                const int    nsamples_y,
+                                                const size_t stride_sample_dst,
+                                                const int    ncols_max) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits::qk;
-    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int     ITER_K          = get_iter_k(type);
+
+    constexpr int     blocks_per_iter = ITER_K / qk;
     const     int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3704,7 +3878,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
     const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+    const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }
 
@@ -3920,4 +4094,4 @@ void ggml_cuda_op_mul_mat_q(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
 
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu
index 32948e4d7a1..d9147202429 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu
@@ -4,26 +4,48 @@
 #include "mmvf.cuh"
 #include "convert.cuh"
 
-template 
+template 
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
-        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ids_stride) {
     const int row         = blockIdx.x;
+    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
-    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int sample_dst  = blockIdx.z;
+    const int tid         = threadIdx.x;
+
+    int token_idx;
+    int channel_x;
+    int channel_y;
+    int sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path; adding these in the normal path causes a perf regression for the n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        token_idx  = ids ? blockIdx.z                                          : 0;
+        channel_x  = ids ? ids[blockIdx.y + token_idx * ids_stride]            : fastdiv((uint32_t) channel_dst, channel_ratio);
+        channel_y  = ids ? fastmodulo(blockIdx.y, nchannels_y)                 : channel_dst;
+        sample_dst = ids ? 0                                                   : blockIdx.z;
+    }
+
     const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
-    const int tid         = threadIdx.x;
 
     constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
 
     x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
     dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+    if constexpr (is_multi_token_id) {
+        y   += token_idx*stride_col_y2*2;
+        dst += token_idx*stride_col_dst;
+    }
 
     bool use_gate = false;
     bool use_bias = false;
@@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f(
     if (use_gate) {
         gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     }
+
+    const int channel_bias = ids ? channel_x : channel_dst;
+
     if constexpr (has_fusion) {
-        const int channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
         }
@@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f(
     }
 }
 
-template
+template
 static void mul_mat_vec_f_switch_fusion(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-        const int64_t ncols, const int64_t nrows,
+        const int64_t ncols, const uint3 nchannels_y,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_f<<>>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f<<>>
+                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
        }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_f<<>>
-        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+    mul_mat_vec_f<<>>
+        (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 
 }
 
-template 
+template 
 void launch_mul_mat_vec_f_cuda(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows,
@@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
     GGML_ASSERT(ncols        % 2 == 0);
     GGML_ASSERT(stride_row   % 2 == 0);
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
     const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
     const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
@@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda(
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
 
     const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
-    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   64: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   96: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  128: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  160: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  192: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  224: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  256: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t ids_stride, cudaStream_t stream) {
+
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
+    if (has_ids) {
+        // Single-token MUL_MAT_ID path
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
     switch (ncols_dst) {
         case 1:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 2:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 3:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 4:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 5:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 6:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 7:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 8:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        enum ggml_prec prec, cudaStream_t stream) {
+        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
 
     if constexpr(std::is_same_v) {
         if (prec == GGML_PREC_DEFAULT) {
             mul_mat_vec_f_cuda_switch_ncols_dst
                 (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             return;
         }
     }
     mul_mat_vec_f_cuda_switch_ncols_dst
         (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
 }
 
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
@@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const size_t ts_src1 = ggml_type_size(src1->type);
     const size_t ts_dst  = ggml_type_size(dst->type);
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
     GGML_ASSERT(ne13 == ne3);
 
     GGML_ASSERT(        nb00       == ts_src0);
@@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const int64_t ncols_dst          = ids ? ne2  : ne1;
     const int64_t nchannels_y        = ids ? ne11 : ne12;
     const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f(
             const float * src0_d = (const float *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cuh
index a09fbdc7202..a50f7c02180 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cuh
@@ -1,5 +1,7 @@
 #include "common.cuh"
 
+#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
+
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
     const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
index d671551c171..ce25ccf427c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
@@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-template 
+template 
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
-        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
+        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+        const uint32_t ids_stride) {
 
     constexpr int qk  = ggml_cuda_type_traits::qk;
     constexpr int qi  = ggml_cuda_type_traits::qi;
@@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q(
     const     int blocks_per_row_x = ncols_x / qk;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
     const uint32_t channel_dst = blockIdx.y;
-    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
-    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
-    const uint32_t sample_dst  = blockIdx.z;
+
+    uint32_t token_idx = 0;
+    uint32_t channel_x;
+    uint32_t channel_y;
+    uint32_t sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+        sample_dst = blockIdx.z;
+    }
+
     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
     const uint32_t sample_y    = sample_dst;
 
@@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q(
         active_glu    = fusion.glu_op;
     }
 
-    const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
     float x_biases[ncols_dst]    = { 0.0f };
     float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
+        const uint32_t channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             // 1. Hide latency by prefetching bias and gate here
@@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q(
     float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    if constexpr (is_multi_token_id) {
+        y += token_idx*stride_col_y;
+    }
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q(
 
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
+    if constexpr (is_multi_token_id) {
+        dst += token_idx*stride_col_dst;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_dst; ++j) {
@@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q(
 }
 
 static std::pair calc_launch_params(
-        const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
         const int warp_size, const mmvq_parameter_table_id table_id) {
     const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
-    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
     const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
-template
+template
 static void mul_mat_vec_q_switch_fusion(
         const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
         const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
+        const uint32_t ids_stride, cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (c_ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_q<<>>
+            mul_mat_vec_q<<>>
                 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
         }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_q<<>>
+    mul_mat_vec_q<<>>
         (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 }
 
 template 
@@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
@@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst(
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+        mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+             channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+             dims.first, dims.second, 0, ids_stride, stream);
+        return;
+    }
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
     switch (ncols_dst) {
         case 1: {
             constexpr int c_ncols_dst = 1;
@@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 2: {
             constexpr int c_ncols_dst = 2;
@@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 3: {
             constexpr int c_ncols_dst = 3;
@@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 4: {
             constexpr int c_ncols_dst = 4;
@@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 5: {
             constexpr int c_ncols_dst = 5;
@@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 6: {
             constexpr int c_ncols_dst = 6;
@@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 7: {
             constexpr int c_ncols_dst = 7;
@@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 8: {
             constexpr int c_ncols_dst = 8;
@@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q8_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_MXFP4:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_M:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q(
     GGML_ASSERT(        nb0        == ts_dst);
     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
 
     const float   * src1_d =       (const float   *) src1->data;
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
@@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q(
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
     mul_mat_vec_q_switch_type(
         src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
         ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
         ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-        ne03,              ne3,           s03, s13,              s3,               stream);
+        ne03,              ne3,           s03, s13,              s3,               ids_stride, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q(
     ggml_cuda_mm_fusion_args_device fusion_local{};
     mul_mat_vec_q_switch_type(
         src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
-        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
 
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cuh
index 4bb10cfaec2..8a154631f69 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cuh
@@ -1,6 +1,7 @@
 #include "common.cuh"
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
 
 void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
index 4f153c5718e..ef98f675aa7 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
@@ -25,19 +25,8 @@ static __global__ void norm_f32(
     }
 
     // sum up partial sums
-    mean_var = warp_reduce_sum(mean_var);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float2 s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
-        }
-        __syncthreads();
-        mean_var = s_sum[lane_id];
-        mean_var = warp_reduce_sum(mean_var);
-    }
+    extern __shared__ float2 s_sum2[];
+    mean_var = block_reduce(mean_var, s_sum2);
 
     const float mean = mean_var.x / ncols;
     const float var = mean_var.y / ncols - mean * mean;
@@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += x[j];
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     const float mean = tmp / group_size;
     tmp = 0.0f;
@@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += xi * xi;
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce(tmp, s_sum);
 
     const float variance = tmp / group_size;
     const float scale = rsqrtf(variance + eps);
@@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int        warp_id = tid / WARP_SIZE;
-        const int        lane_id = tid % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = 0.0f;
-        if (lane_id < (block_size / WARP_SIZE)) {
-            tmp = s_sum[lane_id];
-        }
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
@@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@@ -337,7 +279,7 @@ static void norm_f32_cuda(
         norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+        group_norm_f32<1024><<<num_groups, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
     }
 }
 
@@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<256, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         }
@@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                 add_nchannels_packed, add_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
         l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
index 660c192e48a..31cd00f7781 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
@@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
     return (coord + size) % size;
 }
 
-static __global__ void pad_f32(const float * src, float * dst,
+static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
                                const int lp0, const int rp0, const int lp1, const int rp1,
                                const int lp2, const int rp2, const int lp3, const int rp3,
                                const int ne0, const int ne1, const int ne2, const int ne3,
@@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
             const int64_t i01  = i1 - lp1;
             const int64_t i02  = i2 - lp2;
             const int64_t i03  = i3 - lp3;
-            const int64_t ne02 = ne2 - lp2 - rp2;
-            const int64_t ne01 = ne1 - lp1 - rp1;
-            const int64_t ne00 = ne0 - lp0 - rp0;
 
-            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+            const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
             dst[dst_idx] = src[src_idx];
         } else {
@@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
         const int64_t i02 = wrap_around(i2 - lp2, ne02);
         const int64_t i03 = wrap_around(i3 - lp3, ne03);
 
-        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+        const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
         dst[dst_idx] = src[src_idx];
     }
 }
 
 
-static void pad_f32_cuda(const float * src, float * dst,
+static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
     const int ne0, const int ne1, const int ne2, const int ne3,
     const bool circular, cudaStream_t stream) {
     int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2 * ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
                                                          lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                                                          ne0, ne1, ne2, ne3, circular);
 }
@@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float *             dst_d  = (float *) dst->data;
     cudaStream_t        stream = ctx.stream();
 
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     const int32_t lp0      = ((const int32_t *) (dst->op_params))[0];
     const int32_t rp0      = ((const int32_t *) (dst->op_params))[1];
@@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int32_t rp3      = ((const int32_t *) (dst->op_params))[7];
     const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
-    pad_f32_cuda(src0_d, dst_d,
+    const size_t s00 = nb00 / ggml_type_size(src0->type);
+    const size_t s01 = nb01 / ggml_type_size(src0->type);
+    const size_t s02 = nb02 / ggml_type_size(src0->type);
+    const size_t s03 = nb03 / ggml_type_size(src0->type);
+
+    pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
                  lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                  dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                  (bool) circular, stream);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu
index 5117f9ffc0f..a8c68e44b16 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu
@@ -47,6 +47,131 @@ static __global__ void quantize_q8_1(
     y[ib].ds = make_half2(d, sum);
 }
 
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (!(amax > 0.0f)) {
+        return 0;
+    }
+
+    // FP4 E2M1: max exponent (unbiased) is 2.
+    constexpr int FP4_E2M1_EMAX = 2;
+
+    const float e = log2f(amax);
+
+    // "even" -> round-to-nearest integer, ties-to-even
+    const int e_int = __float2int_rn(e);
+
+    const int shared_exp = e_int - FP4_E2M1_EMAX;
+
+    int biased = shared_exp + 127;
+
+    biased = max(biased, 0);
+    biased = min(biased, 254);
+
+    return static_cast<uint8_t>(biased);
+}
+
+// quantize values in the format mxfp4 is stored which is interleaved nibbles
+// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
+static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
+                                          const int32_t * __restrict__ ids,
+                                          void * __restrict__ vy,
+                                          const int64_t ne00,
+                                          const int64_t s01,
+                                          const int64_t s02,
+                                          const int64_t s03,
+                                          const int64_t ne0,
+                                          const int     ne1,
+                                          const int     ne2) {
+    constexpr int vals_per_scale = 32;
+    constexpr int vals_per_warp  = 2 * vals_per_scale;  // Each warp processes 2 blocks of 32 = 64 values
+
+    const int warp_id = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+
+    if (warp_start_offset >= ne0) {
+        return;
+    }
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
+
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    block_fp4_mmq * y = (block_fp4_mmq *) vy;
+
+    const int64_t block_fp4_mmq_size = 8 * QK_MXFP4;  // 256 values
+    const int64_t ib0                = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
+    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t quad_idx_in_block  = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    const int group_id = lane_id_32 / 4;
+    const int lane_in_group = lane_id_32 % 4;
+    const int base = group_id * 2;
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+
+    uint8_t scales[2];
+
+#pragma unroll
+    for (int b = 0; b < 2; ++b) {
+        const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+        const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
+
+        float amax = fabsf(xi);
+#pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+        }
+
+        const uint8_t e = compute_e8m0_scale(amax);
+        scales[b] = e;
+        const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+
+#if CUDART_VERSION >= 12080
+        const float scaled_val = xi * inv_s;
+
+        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
+        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
+        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
+        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
+
+        if (lane_in_group == 0) {
+            __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
+
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
+        }
+#else
+        // Fallback: manual FP4 conversion using LUT
+        const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+
+        const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base,      WARP_SIZE);
+        const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1,  WARP_SIZE);
+        const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+        const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+
+        if (lane_in_group == 0) {
+            char2 q;
+            q.x = (q_hi_0 << 4) | q_lo_0;
+            q.y = (q_hi_1 << 4) | q_lo_1;
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
+        }
+#endif // CUDART_VERSION >= 12080
+    }
+
+    if (lane_id_32 == 0) {
+        // Store 2 scales packed into 1 uint32
+        y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
+    }
+}
+
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
         const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda(
             break;
     }
 }
+
+void quantize_mmq_mxfp4_cuda(const float *                    x,
+                             const int32_t *                  ids,
+                             void *                           vy,
+                             [[maybe_unused]] const ggml_type type_src0,
+                             const int64_t                    ne00,
+                             const int64_t                    s01,
+                             const int64_t                    s02,
+                             const int64_t                    s03,
+                             const int64_t                    ne0,
+                             const int64_t                    ne1,
+                             const int64_t                    ne2,
+                             const int64_t                    ne3,
+                             cudaStream_t                     stream) {
+    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
+
+    constexpr int nwarps = 8;
+    constexpr int vals_per_warp  = 2 * QK_MXFP4;
+    constexpr int vals_per_block = nwarps * vals_per_warp;
+
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
+    const dim3    num_blocks(ne1, block_num_y, ne2 * ne3);
+    const dim3    block_size(WARP_SIZE, nwarps, 1);
+
+    quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh
index 725ab52443c..6a91df63578 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh
@@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
         const float * x, const int32_t * ids, void * vy,
         ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
         int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_mmq_mxfp4_cuda(const float *   x,
+                             const int32_t * ids,
+                             void *          vy,
+                             ggml_type       type_src0,
+                             int64_t         ne00,
+                             int64_t         s01,
+                             int64_t         s02,
+                             int64_t         s03,
+                             int64_t         ne0,
+                             int64_t         ne1,
+                             int64_t         ne2,
+                             int64_t         ne3,
+                             cudaStream_t    stream);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh
index 6bcae9e52fb..de240fd4413 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh
@@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
     }
 
     // sum up partial sums
-    sum = warp_reduce_sum(sum);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float s_sum[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = sum;
-        }
-        __syncthreads();
-        sum = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            sum = s_sum[lane_id];
-        }
-        sum = warp_reduce_sum(sum);
-    }
+    __shared__ float shared_vals[32];
+    sum = block_reduce(sum, shared_vals);
 
     if (col != 0) {
         return;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
index 71ca6021430..f47392de633 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
@@ -43,10 +43,15 @@ static __device__ void rope_yarn(
 template <bool forward, bool has_ff, typename T, typename D>
 static __global__ void rope_norm(const T *            x,
                                  D *                  dst,
-                                 const int            ne0,
-                                 const int            ne1,
+                                 const int            ne00,
+                                 const int            ne01,
+                                 const int            ne02,
+                                 const int            s01,
+                                 const int            s02,
+                                 const int            s03,
                                  const int            s1,
                                  const int            s2,
+                                 const int            s3,
                                  const int            n_dims,
                                  const int32_t *      pos,
                                  const float          freq_scale,
@@ -59,23 +64,23 @@ static __global__ void rope_norm(const T *            x,
                                  const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
-
-    int       idst = row_dst * ne0 + i0;
-    const int ix   = channel_x*s2 + row_x*s1 + i0;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
+    int       idst = i0 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 + i1 * s01 + i2 * s02 + i3 * s03;
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
     if (set_rows_stride != 0) {
-        idst = row_x * ne0 + i0;
-        idst += row_indices[channel_x] * set_rows_stride;
+        idst = i1 * s1 + i0;
+        idst += row_indices[i2] * set_rows_stride;
     }
 
     const auto & store_coaelsced = [&](float x0, float x1) {
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T *            x,
         return;
     }
 
-    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
@@ -110,10 +115,15 @@ static __global__ void rope_norm(const T *            x,
 template <bool forward, bool has_ff, typename T, typename D>
 static __global__ void rope_neox(const T *            x,
                                  D *                  dst,
-                                 const int            ne0,
-                                 const int            ne1,
+                                 const int            ne00,
+                                 const int            ne01,
+                                 const int            ne02,
+                                 const int            s01,
+                                 const int            s02,
+                                 const int            s03,
                                  const int            s1,
                                  const int            s2,
+                                 const int            s3,
                                  const int            n_dims,
                                  const int32_t *      pos,
                                  const float          freq_scale,
@@ -126,23 +136,24 @@ static __global__ void rope_neox(const T *            x,
                                  const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    int       idst = row_dst * ne0 + i0 / 2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
     if (set_rows_stride != 0) {
-        idst = row_x * ne0 + i0 / 2;
-        idst += row_indices[channel_x] * set_rows_stride;
+        idst = i1 * s1 + i0 / 2;
+        idst += row_indices[i2] * set_rows_stride;
     }
 
     if (i0 >= n_dims) {
@@ -152,7 +163,7 @@ static __global__ void rope_neox(const T *            x,
         return;
     }
 
-    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
@@ -168,24 +179,42 @@ static __global__ void rope_neox(const T *            x,
     dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta);
 }
 
-template
-static __global__ void rope_multi(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
-        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
-    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T>
+static __global__ void rope_multi(const T *            x,
+                                  T *                  dst,
+                                  const int            ne00,
+                                  const int            ne01,
+                                  const int            ne02,
+                                  const int            s01,
+                                  const int            s02,
+                                  const int            s03,
+                                  const int            s1,
+                                  const int            s2,
+                                  const int            s3,
+                                  const int            n_dims,
+                                  const int32_t *      pos,
+                                  const float          freq_scale,
+                                  const float          ext_factor,
+                                  const float          attn_factor,
+                                  const rope_corr_dims corr_dims,
+                                  const float          theta_scale,
+                                  const float *        freq_factors,
+                                  const mrope_sections sections,
+                                  const bool           is_imrope) {
+    const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int idst = row_dst*ne0 + i0/2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     if (i0 >= n_dims) {
         dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@@ -200,27 +229,24 @@ static __global__ void rope_multi(
 
     float theta_base = 0.0;
     if (is_imrope) {
-        if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { // h
-            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { // w
-            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
-            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) {         // h
+            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) {  // w
+            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {  // t
+            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
         // } else {
-        //     theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        //    theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
         }
     } else {
         if (sector < sections.v[0]) {
-            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sections.v[0] && sector < sec_w) {
-            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
         }
     }
 
@@ -238,37 +264,53 @@ static __global__ void rope_multi(
     dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template
-static __global__ void rope_vision(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
-        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+template <bool forward, bool has_ff, typename T>
+static __global__ void rope_vision(const T *            x,
+                                   T *                  dst,
+                                   const int            ne00,
+                                   const int            ne01,
+                                   const int            ne02,
+                                   const int            s01,
+                                   const int            s02,
+                                   const int            s03,
+                                   const int            s1,
+                                   const int            s2,
+                                   const int            s3,
+                                   const int            n_dims,
+                                   const int32_t *      pos,
+                                   const float          freq_scale,
+                                   const float          ext_factor,
+                                   const float          attn_factor,
+                                   const rope_corr_dims corr_dims,
+                                   const float          theta_scale,
+                                   const float *        freq_factors,
+                                   const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int idst = row_dst*ne0 + i0/2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     const int sect_dims = sections.v[0] + sections.v[1];
-    const int sec_w = sections.v[1] + sections.v[0];
-    const int sector = (i0 / 2) % sect_dims;
+    const int sec_w     = sections.v[1] + sections.v[0];
+    const int sector    = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base = pos[channel_x]*powf(theta_scale, p);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base  = pos[i2] * powf(theta_scale, p);
+    } else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
+        theta_base  = pos[i2 + ne02] * powf(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -288,10 +330,15 @@ static __global__ void rope_vision(
 template <bool forward, typename T, typename D>
 static void rope_norm_cuda(const T *            x,
                            D *                  dst,
-                           const int            ne0,
-                           const int            ne1,
+                           const int            ne00,
+                           const int            ne01,
+                           const int            ne02,
+                           const int            s01,
+                           const int            s02,
+                           const int            s03,
                            const int            s1,
                            const int            s2,
+                           const int            s3,
                            const int            n_dims,
                            const int            nr,
                            const int32_t *      pos,
@@ -304,31 +351,36 @@ static void rope_norm_cuda(const T *            x,
                            const int64_t *      row_indices,
                            const int            set_rows_stride,
                            cudaStream_t         stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     }
 }
 
 template <bool forward, typename T, typename D>
 static void rope_neox_cuda(const T *            x,
                            D *                  dst,
-                           const int            ne0,
-                           const int            ne1,
+                           const int            ne00,
+                           const int            ne01,
+                           const int            ne02,
+                           const int            s01,
+                           const int            s02,
+                           const int            s03,
                            const int            s1,
                            const int            s2,
+                           const int            s3,
                            const int            n_dims,
                            const int            nr,
                            const int32_t *      pos,
@@ -341,55 +393,92 @@ static void rope_neox_cuda(const T *            x,
                            const int64_t *      row_indices,
                            const int            set_rows_stride,
                            cudaStream_t         stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     } else {
         rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     }
 }
 
-template
-static void rope_multi_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+template <bool forward, typename T>
+static void rope_multi_cuda(const T *            x,
+                            T *                  dst,
+                            const int            ne00,
+                            const int            ne01,
+                            const int            ne02,
+                            const int            s01,
+                            const int            s02,
+                            const int            s03,
+                            const int            s1,
+                            const int            s2,
+                            const int            s3,
+                            const int            n_dims,
+                            const int            nr,
+                            const int32_t *      pos,
+                            const float          freq_scale,
+                            const float          freq_base,
+                            const float          ext_factor,
+                            const float          attn_factor,
+                            const rope_corr_dims corr_dims,
+                            const float *        freq_factors,
+                            const mrope_sections sections,
+                            const bool           is_imrope,
+                            cudaStream_t         stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_multi<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     } else {
         rope_multi<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     }
 }
 
-template
-static void rope_vision_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+template <bool forward, typename T>
+static void rope_vision_cuda(const T *            x,
+                             T *                  dst,
+                             const int            ne00,
+                             const int            ne01,
+                             const int            ne02,
+                             const int            s01,
+                             const int            s02,
+                             const int            s03,
+                             const int            s1,
+                             const int            s2,
+                             const int            s3,
+                             const int            n_dims,
+                             const int            nr,
+                             const int32_t *      pos,
+                             const float          freq_scale,
+                             const float          freq_base,
+                             const float          ext_factor,
+                             const float          attn_factor,
+                             const rope_corr_dims corr_dims,
+                             const float *        freq_factors,
+                             const mrope_sections sections,
+                             cudaStream_t         stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
     // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
     // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@@ -398,11 +487,11 @@ static void rope_vision_cuda(
 
     if (freq_factors == nullptr) {
         rope_vision<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
         rope_vision<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
@@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
 
     const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
     const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
+    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
+    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
@@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
     // compute
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
-            rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                  freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                  set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
-            rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                 freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                 set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
-            rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
-                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_multi_cuda(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
+            rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                     corr_dims, freq_factors, sections, is_imrope, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_multi_cuda(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
+            rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                     corr_dims, freq_factors, sections, is_imrope, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_vision_cuda(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                      corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_vision_cuda(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                      corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
-            rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                  freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                  set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
-            rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                 freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                 set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
-            rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
-                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
index eeacde0bdb1..dc06d06930e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
@@ -1,6 +1,14 @@
 #include "common.cuh"
 #include "ggml.h"
 #include "softmax.cuh"
+
+#ifdef GGML_USE_HIP
+#include 
+#else
+#include 
+#include 
+#endif // GGML_USE_HIP
+
 #include 
 #include 
 
@@ -67,9 +75,6 @@ static __global__ void soft_max_f32(
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
     const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
@@ -94,21 +99,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
-    }
+    max_val = block_reduce(max_val, buf_iw);
 
     float tmp = 0.0f; // partial sum
 
@@ -126,22 +117,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __syncthreads();
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf_iw[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce(tmp, buf_iw);
 
     if (sinks) {
         tmp += expf(sinks[i02] - max_val);
@@ -160,6 +136,113 @@ static __global__ void soft_max_f32(
         dst[col] = vals[col] * inv_sum;
     }
 }
+
+// TODO: Template to allow keeping ncols in registers if they fit
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+                                                                float * __restrict__ dst,
+                                                                float * __restrict__ tmp_maxs,
+                                                                float * __restrict__ tmp_sums,
+                                                                const soft_max_params p) {
+    namespace cg = cooperative_groups;
+
+    const cg::grid_group g = cg::this_grid();
+
+    const int tid               = threadIdx.x;
+    const int col_start         = blockIdx.x * blockDim.x + tid;
+    const int n_elem_per_thread = 4;
+
+    float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+    float     local_max                     = -INFINITY;
+    const int step_size                     = gridDim.x * blockDim.x;
+    __shared__ float shared_vals[32];
+
+    // Compute thread-local max
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            local_max = fmaxf(local_max, local_vals[i]);
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Compute CTA-level max
+    local_max = block_reduce(local_max, shared_vals);
+
+    // Store CTA-level max to GMEM
+    if (tid == 0) {
+        tmp_maxs[blockIdx.x] = local_max;
+    }
+    g.sync();
+
+    // Compute compute global max from CTA-level maxs
+    assert(gridDim.x < blockDim.x);  // currently we only support this case
+    if (tid < gridDim.x) {
+        local_max = tmp_maxs[tid];
+    } else {
+        local_max = -INFINITY;
+    }
+    local_max = block_reduce(local_max, shared_vals);
+
+    // Compute softmax dividends, accumulate divisor
+    float tmp_expf = 0.0f;
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                const float tmp = expf(local_vals[i] - local_max);
+                tmp_expf += tmp;
+                dst[idx] = tmp;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Reduce divisor within CTA
+    tmp_expf = block_reduce(tmp_expf, shared_vals);
+
+    // Store CTA-level sum to GMEM
+    if (tid == 0) {
+        tmp_sums[blockIdx.x] = tmp_expf;
+    }
+    g.sync();
+
+    // Compute global sum from CTA-level sums
+    if (tid < gridDim.x) {
+        tmp_expf = tmp_sums[tid];
+    } else {
+        tmp_expf = 0.0f;
+    }
+    tmp_expf = block_reduce(tmp_expf, shared_vals);
+
+    // Divide dividend by global sum + store data
+    for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                dst[idx] = local_vals[i] / tmp_expf;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+}
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif // __clang__
@@ -216,9 +299,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
     soft_max_f32<<>>(x, mask, sinks, dst, p);
 }
 
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+                                                     float * __restrict__ dst,
+                                                     float * __restrict__ tmp_maxs,
+                                                     float * __restrict__ tmp_sums,
+                                                     const soft_max_params p)
+// We loop over all instead of parallelizing across gridDim.y as cooperative groups
+// currently only support synchronizing the complete grid if not launched as a cluster group
+// (which requires CC > 9.0)
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+{
+    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
+    }
+}
 
-template
-static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
+template 
+static void soft_max_f32_cuda(const float *                                x,
+                              const T *                                    mask,
+                              const float *                                sinks,
+                              float *                                      dst,
+                              const soft_max_params &                      params,
+                              cudaStream_t                                 stream,
+                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
 
@@ -236,8 +341,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
     if (nbytes_shared <= smpbo) {
         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
-        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<<>>(x, mask, sinks, dst, params);
+        // Parallelize across SMs for top-p/dist-sampling
+        // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
+        // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
+        if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
+            ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
+            params.scale == 1.0f && params.max_bias == 0.0f) {
+            ggml_cuda_pool_alloc tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+            ggml_cuda_pool_alloc tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+
+            void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+                                     (void *) &tmp_sums_alloc.ptr, (void *) const_cast(¶ms) };
+            CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                                   dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                                   dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+        } else {
+            const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+            soft_max_f32
+                <<>>(x, mask, sinks, dst, params);
+        }
     }
 }
 
@@ -315,9 +437,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
index 41979733601..6d5ea704c65 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
@@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     const int threads = 128;
     GGML_ASSERT(nr % threads == 0);
 
-    if (n_t <= 32) {
-        const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-        if (nc == 4) {
-            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
-            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                     dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+    auto launch_kernel = [&](auto NC) {
+        constexpr int kNC = decltype(NC)::value;
+        if (n_t <= 32) {
+            const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                       dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
-        }
-    } else {
-        if (nc == 4) {
-            const int64_t split_n_t = 32;
-            dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<<>>(
-                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else if (nc == 3) {
             const int64_t split_n_t = 32;
             dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<<>>(
+            ssm_conv_long_token_f32<<>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
-        } else {
-            GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
         }
+    };
+
+    switch (nc) {
+        case 3: launch_kernel(std::integral_constant{}); break;
+        case 4: launch_kernel(std::integral_constant{}); break;
+        case 9: launch_kernel(std::integral_constant{}); break;
+        default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
index 6b424381df5..c1d4e2bc8df 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
 #endif // __clang__
 
 // assumes as many threads as d_state
-template 
+template 
 __global__ void __launch_bounds__(d_state, 1)
     ssm_scan_f32_group(
         const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1)
         const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
         const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
 
-    const int head_idx = (blockIdx.x * splitH) / d_head;
-    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
-    const int seq_idx = blockIdx.y;
+    const int warp     = threadIdx.x / WARP_SIZE;
+    const int lane     = threadIdx.x % WARP_SIZE;
+    const int warp_idx = blockIdx.x  * c_factor + warp;
+
+    const int head_idx =  warp_idx / d_head;
+    const int head_off = (warp_idx % d_head) * sizeof(float);
+    const int seq_idx  = blockIdx.y;
 
     const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
 
-    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
-    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
-    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
-    float *       y_block  = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
-    float *       s_block  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
+    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_warp  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
+    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_warp  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_warp  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_warp  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float *       y_warp  = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
+    float *       s_warp  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
 
     // strides across n_seq_tokens
     const int stride_x  = src1_nb2 / sizeof(float);
@@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1)
     const int stride_C  = src5_nb2 / sizeof(float);
     const int stride_y  = n_head * d_head;
 
-    float state[splitH];
-    // for the parallel accumulation
-    __shared__ float stateC[splitH * d_state];
+    float state[c_factor];
+    float state_sum = 0.0f;
 
 #pragma unroll
-    for (int j = 0; j < splitH; j++) {
-        state[j] = s0_block[j * d_state + threadIdx.x];
+    for (int j = 0; j < c_factor; j++) {
+        state[j] = s0_warp[WARP_SIZE * j + lane];
     }
 
     for (int64_t i = 0; i < n_tok; i++) {
-        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
-        // TODO: only calculate B and C once per head group
-        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
-        float dt_soft_plus = dt_block[i * stride_dt];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(expf(dt_soft_plus));
-        }
-        const float dA = expf(dt_soft_plus * A_block[0]);
-        const float B = B_block[i * stride_B + threadIdx.x];
-        const float C = C_block[i * stride_C + threadIdx.x];
+        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
+        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
+        const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
 
-        // across d_head
+        state_sum = 0.0f;
+        const float dA   = expf(dt_soft_plus * A_warp[0]);
+        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
 #pragma unroll
-        for (int j = 0; j < splitH; j++) {
-            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
-
-            state[j] = (state[j] * dA) + (B * x_dt);
-
-            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        for (int j = 0; j < c_factor; j++) {
+            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
+            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
+            state[j] = (state[j] * dA) + (B_val * x_dt);
+            state_sum += state[j] * C_val;
         }
 
-        __syncthreads();
-
-        // parallel accumulation for stateC
-        // TODO: simplify
-        {
-            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
-            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
-
-            // reduce until w matches the warp size
-            // TODO: does this work even when the physical warp size is 64?
-#pragma unroll
-            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
-                // (assuming there are d_state threads)
-#pragma unroll
-                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
-                    // TODO: check for bank conflicts
-                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
-                    stateC[k] += stateC[k + (w >> 1)];
-
-                }
-                __syncthreads();
-            }
-
-            static_assert(splitH >= d_state / WARP_SIZE);
+        // parallel accumulation for output
+        state_sum = warp_reduce_sum(state_sum);
 
-#pragma unroll
-            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
-                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
-                y = warp_reduce_sum(y);
-
-                // store the above accumulations
-                if (threadIdx.x % WARP_SIZE == 0) {
-                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
-                    y_block[i * stride_y + k] = y;
-                }
-            }
+        if (lane == 0) {
+            y_warp[i * stride_y] = state_sum;
         }
     }
 
     // write back the state
 #pragma unroll
-    for (int j = 0; j < splitH; j++) {
-        s_block[j * d_state + threadIdx.x] = state[j];
+    for (int j = 0; j < c_factor; j++) {
+        s_warp[WARP_SIZE * j + lane] = state[j];
     }
 }
 
@@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
-    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            GGML_ASSERT(d_state % threads == 0);
-            // NOTE: can be any power of two between 4 and 64
-            const int splitH = 16;
-            GGML_ASSERT(head_dim % splitH == 0);
-            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
-            ssm_scan_f32_group<16, 128><<>>(
+            constexpr int threads   = 128;
+            constexpr int num_warps = threads/WARP_SIZE;
+
+            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+            ssm_scan_f32_group<128/WARP_SIZE, 128><<>>(
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
         } else if (d_state == 256) { // Falcon-H1
-            const int threads = 256;
-            // NOTE: can be any power of two between 8 and 64
-            const int splitH = 16;
-            GGML_ASSERT(head_dim % splitH == 0);
-            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
-            ssm_scan_f32_group<16, 256><<>>(
+            constexpr int threads   = 256;
+            constexpr int num_warps = threads/WARP_SIZE;
+
+            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+            ssm_scan_f32_group<256/WARP_SIZE, 256><<>>(
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
         }
     } else {
         // Mamba-1
+        constexpr int threads = 128;
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(n_group == 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
new file mode 100644
index 00000000000..1f554d81e5e
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
new file mode 100644
index 00000000000..264751d65ec
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu
new file mode 100644
index 00000000000..785a18389f2
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu
@@ -0,0 +1,95 @@
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#    include 
+#    if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2)
+#        define CUB_TOP_K_AVAILABLE
+using namespace cub;
+#    endif  // CCCL >= 3.2
+#endif      // GGML_CUDA_USE_CUB
+
+#ifdef CUB_TOP_K_AVAILABLE
+
+static void top_k_cub(ggml_cuda_pool & pool,
+                      const float *    src,
+                      int *            dst,
+                      const int        ncols,
+                      const int        k,
+                      cudaStream_t     stream) {
+    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+                                                 cuda::execution::output_ordering::unsorted);
+    auto stream_env   = cuda::stream_ref{ stream };
+    auto env          = cuda::std::execution::env{ stream_env, requirements };
+
+    auto indexes_in = cuda::make_counting_iterator(0);
+
+    size_t temp_storage_bytes = 0;
+    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
+                         env);
+
+    ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes);
+    void *                        d_temp_storage = temp_storage_alloc.get();
+
+    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
+                         ncols, k, env);
+}
+
+#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+#endif                            // CUB_TOP_K_AVAILABLE
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) src0->data;
+    int *               dst_d  = (int *) dst->data;
+    cudaStream_t        stream = ctx.stream();
+
+    // preconditions: contiguous f32 input and i32 output, as assumed by both the CUB and argsort paths
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t    ncols = src0->ne[0];
+    const int64_t    nrows = ggml_nrows(src0);
+    const int64_t    k     = dst->ne[0];
+    ggml_cuda_pool & pool  = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+    // https://github.com/NVIDIA/cccl/issues/6391
+    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
+    for (int i = 0; i < nrows; i++) {
+        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+    }
+#elif defined(GGML_CUDA_USE_CUB)  // CUB_TOP_K_AVAILABLE
+    // Fall back to argsort + copy
+    const int    ncols_pad      = next_power_of_2(ncols);
+    const size_t shared_mem     = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows);
+    int *                     tmp_dst = temp_dst_alloc.get();
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    }
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#else                             // GGML_CUDA_USE_CUB
+    ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows);
+    int *                     tmp_dst = temp_dst_alloc.get();
+    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#endif
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh
new file mode 100644
index 00000000000..f4d8f61e5b3
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu
index 572379fcbf0..08a88990dde 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu
@@ -5,6 +5,13 @@
 #include 
 #include 
 
+// Kernel config struct - passed by value to CUDA kernel
+struct topk_moe_config {
+    bool use_sigmoid;
+    bool with_norm;
+    bool delayed_softmax;
+};
+
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 template 
 __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
     }
 }
 
+template 
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        vals[i]           = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+    }
+}
+
 /*
     This kernel does the following:
     1. optionally softmax over the logits per token [n_experts, n_tokens]
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template 
-__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
-                                                                  float *       weights,
-                                                                  int32_t *     ids,
-                                                                  const int     n_rows,
-                                                                  const int     n_expert_used,
-                                                                  const float   clamp_val) {
+template 
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,
+                                                                  float *               weights,
+                                                                  int32_t *             ids,
+                                                                  float *               bias,
+                                                                  const int             n_rows,
+                                                                  const int             n_expert_used,
+                                                                  const float           clamp_val,
+                                                                  const float           scale_val,
+                                                                  const topk_moe_config config) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt[experts_per_thread];
 
+    // Initialize all slots to -INFINITY
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = -INFINITY;
+    }
+
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert  = i + threadIdx.x;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace(wt, n_experts, threadIdx.x);
+    if (!config.delayed_softmax) {
+        if (config.use_sigmoid) {
+           sigmoid_warp_inplace(wt, n_experts, threadIdx.x);
+        } else {
+           softmax_warp_inplace(wt, n_experts, threadIdx.x);
+        }
+    }
+
+    // selection_wt is only needed when bias is present (selection uses wt + bias)
+    // when no bias, we use wt directly for both selection and weight values
+    float selection_wt[has_bias ? experts_per_thread : 1];
+
+    if constexpr (has_bias) {
+#pragma unroll
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_wt[i] = -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_experts; i += WARP_SIZE) {
+            const int expert = i + threadIdx.x;
+            selection_wt[i / WARP_SIZE] =
+                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+        }
     }
 
     //at this point, each thread holds either a portion of the softmax distribution
@@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
 
+        if constexpr (has_bias) {
+            float max_val_s = selection_wt[0];
+
 #pragma unroll
-        for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = threadIdx.x + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
-                max_expert = expert;
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+                    max_val    = wt[i];
+                    max_val_s  = selection_wt[i];
+                    max_expert = expert;
+                }
             }
-        }
 
 #pragma unroll
-        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
-                max_val    = val;
-                max_expert = expert;
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+                    max_val    = val;
+                    max_val_s  = val_s;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+            }
+        } else {
+#pragma unroll
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                    max_val    = wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val > max_val || (val == max_val && expert < max_expert)) {
+                    max_val    = val;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                wt[max_expert / WARP_SIZE] = -INFINITY;
             }
         }
 
@@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
-
             ids[k] = max_expert;
-            if constexpr (with_norm) {
+            if (config.with_norm) {
                 wt_sum += max_val;
             }
         }
     }
 
-    if constexpr (with_norm) {
+    if (config.with_norm) {
         wt_sum              = warp_reduce_sum(wt_sum);
         wt_sum              = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
@@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
+    if (config.delayed_softmax) {
         softmax_warp_inplace(output_weights, n_expert_used, threadIdx.x);
     }
 
@@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
         if (idx < n_expert_used) {
-            weights[idx] = output_weights[i];
+            weights[idx] = output_weights[i] * scale_val;
         }
     }
-
-    if (!with_norm) {
-        GGML_UNUSED(clamp_val);
-    }
 }
 
-template 
+template
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
+                                 float *                     bias,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used,
-                                 const float                 clamp_val) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+                                 const float                 clamp_val,
+                                 const float                 scale_val,
+                                 const topk_moe_config       config) {
+    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+                "delayed softmax is not supported with weight normalization");
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
@@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
+            break;
+        case 576:
+            topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     }
 }
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax,
-                           ggml_tensor *               clamp) {
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -245,92 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
+    float *       bias_d    = bias ? (float *) bias->data : nullptr;
+
+    float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     const int n_expert_used = weights->ne[1];
 
+    const bool with_norm = clamp != nullptr;
+
     float clamp_val = -INFINITY;
-    if (with_norm) {
-        if (clamp) {
-            clamp_val = ggml_get_op_params_f32(clamp, 0);
-        }
-        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
-    } else {
-        GGML_ASSERT(clamp == nullptr);
-        if (delayed_softmax) {
-            launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                              clamp_val);
-        } else {
-            launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                               clamp_val);
-        }
+    if (clamp) {
+        clamp_val = ggml_get_op_params_f32(clamp, 0);
     }
-}
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    topk_moe_config config;
+    config.use_sigmoid     = args.sigmoid;
+    config.with_norm       = with_norm;
+    config.delayed_softmax = args.delayed_softmax;
 
-    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
-
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
-        return false;
+    if (bias) {
+        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
+    } else {
+        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
     }
+}
 
-    if (scale != 1.0f || max_bias != 0.0f) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids) {
+    const int n_expert = ids->nb[1] / ids->nb[0];
+    if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
         return false;
     }
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
+    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
         return false;
     }
 
-    const int n_expert = softmax->ne[0];
-    // n_expert must be a power of 2
-    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
-        return false;
-    }
+    if (gating_op->op == GGML_OP_SOFT_MAX) {
+        const ggml_tensor * softmax  = gating_op;
+        float               scale    = 1.0f;
+        float               max_bias = 0.0f;
 
-    if (clamp) {
-        if (clamp->op != GGML_OP_CLAMP) {
+        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+        if (!ggml_is_contiguous(softmax->src[0])) {
             return false;
         }
-        float max_val = ggml_get_op_params_f32(clamp, 1);
 
-        if (max_val != INFINITY) {
+        if (scale != 1.0f || max_bias != 0.0f) {
             return false;
         }
-    }
-
-
-    return true;
-}
-
-std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
-    static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
-                                                            GGML_OP_RESHAPE };
-
-    static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
-    static std::initializer_list delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+    } else if (gating_op->op == GGML_OP_UNARY) {
+        ggml_unary_op op = ggml_get_unary_op(gating_op);
 
-    if (norm) {
-        return norm_ops;
+        if (op != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
     }
 
-    return no_norm_ops;
+    return true;
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh
index 2eff408b030..243dc2f1c41 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh
@@ -3,14 +3,25 @@
 
 #include 
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax = false,
-                           ggml_tensor *               weight_clamp    = nullptr);
+struct ggml_cuda_topk_moe_args {
+    bool sigmoid{};
+    bool softmax{};
+    bool delayed_softmax{};
+    bool prob_bias{};
+    bool norm{};
+    bool scale{};
+};
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args);
 
-std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+                                   const ggml_tensor * weights,
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh
index 6baab1176ff..ab803aca21b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vecdotq.cuh
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
 #endif
 }
 
+static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
+    // v holds 7 explicit sign bits; the 8th sign is encoded implicitly as the
+    // parity (popcount) of the other 7, so XOR restores it instead of masking
+    const uint32_t p = __popc(v) & 1;
+    const uint32_t s = v ^ p << 7;
+    // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
+    return s * 0x01010101;
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     int sumi = 0;
 #pragma unroll
     for (int k0 = 0; k0 < 8; k0 += 2) {
-        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
-        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
+        const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
 
-        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
         sumi = ggml_cuda_dp4a(grid0, u0, sumi);
 
-        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
         sumi = ggml_cuda_dp4a(grid1, u1, sumi);
     }
 
-    const int ls = aux32 >> 28;
-    sumi = (ls*sumi + sumi/2)/4;
+    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
+    sumi = sumi * ls / 8;           // (sumi * scale + sumi / 2) / 4
     const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
     return d * sumi;
 }
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     int sumi1 = 0;
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
-        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
-        const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l0/2] >> 9));
-
-        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
+        const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
 
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
         const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+        const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
 
-        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
-
-        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
index 3b3086778ee..ba032cfab4b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
@@ -10,6 +10,10 @@
 #include 
 #endif // CUDART_VERSION >= 12050
 
+#if CUDART_VERSION >= 12080
+#include 
+#endif // CUDART_VERSION >= 12080
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
index d89e35a8edf..14473a97ca2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
@@ -47,9 +47,11 @@
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
+#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceReset hipDeviceReset
 #define cudaDeviceSynchronize hipDeviceSynchronize
@@ -74,6 +76,7 @@
 #define cudaHostRegisterPortable hipHostRegisterPortable
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
+#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
 #define cudaLaunchHostFunc hipLaunchHostFunc
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
@@ -139,6 +142,8 @@
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
+#define cudaFuncSetAttribute hipFuncSetAttribute
+#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
index 221e67f96a7..1abb8acfd4b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h
@@ -61,6 +61,7 @@
 #define cudaHostRegisterPortable musaHostRegisterPortable
 #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
 #define cudaHostUnregister musaHostUnregister
+#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost
diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
index 23b6889919f..80037d24361 100644
--- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
@@ -62,6 +62,8 @@ file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
+file(GLOB   SRCS "../ggml-cuda/template-instances/mmf*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
 if (GGML_CUDA_FA_ALL_QUANTS)
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h
index 7e17032c72c..b502fdf78e1 100644
--- a/ml/backend/ggml/ggml/src/ggml-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-impl.h
@@ -24,10 +24,6 @@
 #include 
 #endif
 
-#if defined(__F16C__)
-#include 
-#endif
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -102,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
     }
 }
 
+static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
 static inline float ggml_compute_softplus_f32(float input) {
     return (input > 20.0f) ? input : logf(1 + expf(input));
 }
@@ -615,6 +615,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
         if (node->op != ops[i]) {
             return false;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
         if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
index 63418fe1430..42054d841aa 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt
@@ -23,11 +23,6 @@ if (GGML_METAL_NDEBUG)
     add_compile_definitions(GGML_METAL_NDEBUG)
 endif()
 
-# copy metal files to bin directory
-configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
-configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
-configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
-
 set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
@@ -37,12 +32,12 @@ if (GGML_METAL_EMBED_LIBRARY)
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
     set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
-    file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+    file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
 
     # merge ggml-common.h and ggml-metal.metal into a single file
-    set(METALLIB_EMBED_ASM        "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
-    set(METALLIB_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
-    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
+    set(METALLIB_EMBED_ASM        "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+    set(METALLIB_SOURCE_EMBED     "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
 
     add_custom_command(
         OUTPUT "${METALLIB_EMBED_ASM}"
@@ -62,6 +57,11 @@ if (GGML_METAL_EMBED_LIBRARY)
 
     target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
 else()
+    # copy metal files to bin directory
+    configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
+    configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
+    configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
+
     if (GGML_METAL_SHADER_DEBUG)
         # custom command to do the following:
         #   xcrun -sdk macosx metal    -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
@@ -71,7 +71,7 @@ else()
         #       disabling fast math is needed in order to pass tests/test-backend-ops
         # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
         # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
-        #       ref: https://github.com/ggerganov/whisper.cpp/issues/1720
+        #       ref: https://github.com/ggml-org/whisper.cpp/issues/1720
         # note: adding -g causes segmentation fault during compile
         #set(XC_FLAGS -fno-fast-math -fno-inline -g)
         set(XC_FLAGS -fno-fast-math -fno-inline)
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp
index 95627d38665..2eb9820bff9 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -264,15 +264,26 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vectorev_cpy = ggml_metal_device_event_init(dev);
+
+    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
+
+    snprintf(res->name, sizeof(res->name), "%s", props_dev->name);
 
     res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -206,9 +214,15 @@ void ggml_metal_free(ggml_metal_t ctx) {
 
     dispatch_release(ctx->d_queue);
 
+    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);
+
     free(ctx);
 }
 
+const char * ggml_metal_get_name(ggml_metal_t ctx) {
+    return ctx->name;
+}
+
 void ggml_metal_synchronize(ggml_metal_t ctx) {
     // wait for any backend operations to finish
     if (ctx->cmd_buf_last) {
@@ -273,8 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
         // wrap the source data into a Metal buffer
         id device = ggml_metal_device_get_obj(ctx->dev);
         id buf_src = [device newBufferWithBytes:data
-                                                         length:size
-                                                        options:MTLResourceStorageModeShared];
+                                                    length:size
+                                                   options:MTLResourceStorageModeShared];
 
         GGML_ASSERT(buf_src);
 
@@ -316,9 +330,9 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
     @autoreleasepool {
         id device = ggml_metal_device_get_obj(ctx->dev);
         id buf_dst = [device newBufferWithBytesNoCopy:data
-                                                               length:size
-                                                              options:MTLResourceStorageModeShared
-                                                          deallocator:nil];
+                                                          length:size
+                                                         options:MTLResourceStorageModeShared
+                                                     deallocator:nil];
 
         GGML_ASSERT(buf_dst);
 
@@ -356,9 +370,52 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
     }
 }
 
+bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    @autoreleasepool {
+        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);
+        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);
+
+        if (bid_src.metal == nil || bid_dst.metal == nil) {
+            return false;
+        }
+
+        // queue the copy operation into the Metal context
+        // this will be queued at the end, after any currently ongoing GPU operations
+        id queue = ggml_metal_device_get_queue(ctx_src->dev);
+        id cmd_buf = [queue commandBuffer];
+        id encoder = [cmd_buf blitCommandEncoder];
+
+        [encoder copyFromBuffer:bid_src.metal
+                   sourceOffset:bid_src.offs
+                       toBuffer:bid_dst.metal
+              destinationOffset:bid_dst.offs
+                           size:ggml_nbytes(src)];
+
+        [encoder endEncoding];
+
+        ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
+        ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
+
+        [cmd_buf commit];
+
+        // do not wait here for completion
+        //[cmd_buf waitUntilCompleted];
+
+        // instead, remember a reference to the command buffer and wait for it later if needed
+        [ctx_src->cmd_bufs_ext addObject:cmd_buf];
+        ctx_src->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+
+        ggml_metal_event_wait(ctx_dst, ev_cpy);
+
+        return true;
+    }
+}
+
 enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
     // number of nodes encoded by the main thread (empirically determined)
-    const int n_main = 64;
+    const int n_main = MAX(64, 0.1*gf->n_nodes);
 
     // number of threads in addition to the main thread
     const int n_cb = ctx->n_cb;
@@ -530,6 +587,42 @@ void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
     //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
 }
 
+void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id queue = ggml_metal_device_get_queue(ctx->dev);
+        id cmd_buf = [queue commandBuffer];
+
+        ggml_metal_event_encode_signal(ev, cmd_buf);
+
+        [cmd_buf commit];
+
+        [ctx->cmd_bufs_ext addObject:cmd_buf];
+        ctx->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+    }
+}
+
+void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id queue = ggml_metal_device_get_queue(ctx->dev);
+        id cmd_buf = [queue commandBuffer];
+
+        ggml_metal_event_encode_wait(ev, cmd_buf);
+
+        [cmd_buf commit];
+
+        [ctx->cmd_bufs_ext addObject:cmd_buf];
+        ctx->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+    }
+}
+
+ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
+    return ctx->ev_cpy;
+}
+
 void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
     if (ctx->n_cb != n_cb) {
         ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
index 83385c9ef60..06f3d804590 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -17,10 +17,12 @@ struct ggml_metal_device_deleter {
 
 typedef std::unique_ptr ggml_metal_device_ptr;
 
-ggml_metal_device_t ggml_metal_device_get(void) {
-    static ggml_metal_device_ptr ctx { ggml_metal_device_init() };
+ggml_metal_device_t ggml_metal_device_get(int device) {
+    static std::vector devs;
 
-    return ctx.get();
+    devs.emplace_back(ggml_metal_device_init(device));
+
+    return devs.back().get();
 }
 
 struct ggml_metal_pipelines {
@@ -94,6 +96,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
@@ -149,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int n = op->src[0]->ne[0];
+
+    snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_n=%d", base, n);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.nsg  = 1;
+    res.smem = 0;
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
@@ -165,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
     char base[256];
     char name[256];
 
-    const int64_t n = ggml_nelements(op);
+    int op_num = -1;
 
-    const char * op_str = "undefined";
     switch (op->op) {
-        case GGML_OP_SCALE:      op_str = "scale";      break;
-        case GGML_OP_FILL:       op_str = "fill";       break;
-        case GGML_OP_CLAMP:      op_str = "clamp";      break;
-        case GGML_OP_SQR:        op_str = "sqr";        break;
-        case GGML_OP_SQRT:       op_str = "sqrt";       break;
-        case GGML_OP_SIN:        op_str = "sin";        break;
-        case GGML_OP_COS:        op_str = "cos";        break;
-        case GGML_OP_LOG:        op_str = "log";        break;
-        case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
+        case GGML_OP_SCALE:      op_num = OP_UNARY_NUM_SCALE;      break;
+        case GGML_OP_FILL:       op_num = OP_UNARY_NUM_FILL;       break;
+        case GGML_OP_CLAMP:      op_num = OP_UNARY_NUM_CLAMP;      break;
+        case GGML_OP_SQR:        op_num = OP_UNARY_NUM_SQR;        break;
+        case GGML_OP_SQRT:       op_num = OP_UNARY_NUM_SQRT;       break;
+        case GGML_OP_SIN:        op_num = OP_UNARY_NUM_SIN;        break;
+        case GGML_OP_COS:        op_num = OP_UNARY_NUM_COS;        break;
+        case GGML_OP_LOG:        op_num = OP_UNARY_NUM_LOG;        break;
+        case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_TANH:        op_str = "tanh";        break;
-                case GGML_UNARY_OP_RELU:        op_str = "relu";        break;
-                case GGML_UNARY_OP_SIGMOID:     op_str = "sigmoid";     break;
-                case GGML_UNARY_OP_GELU:        op_str = "gelu";        break;
-                case GGML_UNARY_OP_GELU_ERF:    op_str = "gelu_erf";    break;
-                case GGML_UNARY_OP_GELU_QUICK:  op_str = "gelu_quick";  break;
-                case GGML_UNARY_OP_SILU:        op_str = "silu";        break;
-                case GGML_UNARY_OP_ELU:         op_str = "elu";         break;
-                case GGML_UNARY_OP_NEG:         op_str = "neg";         break;
-                case GGML_UNARY_OP_ABS:         op_str = "abs";         break;
-                case GGML_UNARY_OP_SGN:         op_str = "sgn";         break;
-                case GGML_UNARY_OP_STEP:        op_str = "step";        break;
-                case GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
-                case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
-                case GGML_UNARY_OP_EXP:         op_str = "exp";         break;
-                case GGML_UNARY_OP_SOFTPLUS:    op_str = "softplus";    break;
-                case GGML_UNARY_OP_EXPM1:       op_str = "expm1";       break;
+                case GGML_UNARY_OP_TANH:        op_num = OP_UNARY_NUM_TANH;        break;
+                case GGML_UNARY_OP_RELU:        op_num = OP_UNARY_NUM_RELU;        break;
+                case GGML_UNARY_OP_SIGMOID:     op_num = OP_UNARY_NUM_SIGMOID;     break;
+                case GGML_UNARY_OP_GELU:        op_num = OP_UNARY_NUM_GELU;        break;
+                case GGML_UNARY_OP_GELU_ERF:    op_num = OP_UNARY_NUM_GELU_ERF;    break;
+                case GGML_UNARY_OP_GELU_QUICK:  op_num = OP_UNARY_NUM_GELU_QUICK;  break;
+                case GGML_UNARY_OP_SILU:        op_num = OP_UNARY_NUM_SILU;        break;
+                case GGML_UNARY_OP_ELU:         op_num = OP_UNARY_NUM_ELU;         break;
+                case GGML_UNARY_OP_NEG:         op_num = OP_UNARY_NUM_NEG;         break;
+                case GGML_UNARY_OP_ABS:         op_num = OP_UNARY_NUM_ABS;         break;
+                case GGML_UNARY_OP_SGN:         op_num = OP_UNARY_NUM_SGN;         break;
+                case GGML_UNARY_OP_STEP:        op_num = OP_UNARY_NUM_STEP;        break;
+                case GGML_UNARY_OP_HARDSWISH:   op_num = OP_UNARY_NUM_HARDSWISH;   break;
+                case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
+                case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;
+                case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;
+                case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
     };
 
-    const char * suffix = "";
-    if (n % 4 == 0) {
-        suffix = "_4";
-    }
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+    const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
+
+    snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
+        ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
+    res.c4  = is_c4;
+    res.cnt = is_cnt;
+
     return res;
 }
 
@@ -273,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
+    int op_num = -1;
+
     switch (op->op) {
-        case GGML_OP_SUM_ROWS:
-            op_str = "sum_rows"; break;
-        case GGML_OP_MEAN:
-            op_str = "mean"; break;
+        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
+        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;
         default: GGML_ABORT("fatal error");
     };
 
-    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d", base, op_num);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     res.smem = 32*sizeof(float);
 
+    if (is_c4) {
+        res.smem *= 4;
+    }
+
+    res.c4  = is_c4;
+
     return res;
 }
 
@@ -507,6 +577,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int nsg = 8;
+    const int n   = op->src[1]->ne[1];
+    const int k   = op->src[1]->ne[0];
+
+    snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
+        ggml_metal_cv_set_int16(cv, n,   FC_SOLVE_TRI + 1);
+        ggml_metal_cv_set_int16(cv, k,   FC_SOLVE_TRI + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg  = nsg;
+    res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
     char base[256];
     char name[256];
@@ -1315,71 +1415,95 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
     GGML_UNUSED(op);
 }
 
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
-        ggml_metal_library_t lib,
-        ggml_op op,
-        int32_t n_fuse,
-        bool row) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
-    switch (op) {
-        case GGML_OP_ADD:   op_str = "add";   break;
-        case GGML_OP_SUB:   op_str = "sub";   break;
-        case GGML_OP_MUL:   op_str = "mul";   break;
-        case GGML_OP_DIV:   op_str = "div";   break;
+    int op_num = -1;
+
+    switch (op->op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
         default: GGML_ABORT("fatal error");
     };
 
-    if (row) {
-        snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
-    } else {
-        snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
-    }
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t1_str = ggml_type_name(op->src[1]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
+
+    const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
+    res.c4  = is_c4;
+    res.cnt = is_rb;
+
     return res;
 }
 
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
-    assert(op->op == GGML_OP_L2_NORM);
-
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
-
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_l2_norm_f32");
-    snprintf(name, 256, "%s", base);
+    int op_num = -1;
+
+    switch (op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
+    snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-    }
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    res.smem = 32*sizeof(float);
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, 1,      FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, false,  FC_BIN + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
-    assert(op->op == GGML_OP_SOLVE_TRI);
-
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_L2_NORM);
 
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_solve_tri_f32");
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
+
+    snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1387,6 +1511,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
+    res.c4   = is_c4;
+    res.smem = 32*sizeof(float);
+
     return res;
 }
 
@@ -1704,3 +1831,60 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
 
     return res;
 }
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor *  op) {
+    GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor *  op) {
+    assert(op->op == GGML_OP_COUNT_EQUAL);
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
+
+    GGML_ASSERT(op->src[0]->type == op->src[1]->type);
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
+    GGML_ASSERT(op->type == GGML_TYPE_I64);
+
+    // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
+    GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
+
+    char base[256];
+    char name[256];
+
+    int nsg = 1;
+    while (32*nsg < ne00 && nsg < 32) {
+        nsg *= 2;
+    }
+
+    snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.smem = 32 * sizeof(int32_t);
+    res.nsg  = nsg;
+
+    return res;
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
index 8a9d1746018..93d7f6a216f 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
@@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
     int nr1;
 
     size_t smem;
+
+    bool c4;
+    bool cnt;
 };
 
 int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
@@ -104,9 +107,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -120,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -131,9 +137,9 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one           (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -148,6 +154,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
@@ -203,7 +211,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
 //
 
 struct ggml_metal_device_props {
+    int device;
     char name[128];
+    char desc[128];
 
     size_t max_buffer_size;
     size_t max_working_set_size;
@@ -218,13 +228,19 @@ struct ggml_metal_device_props {
     bool use_shared_buffers;
 
     bool supports_gpu_family_apple7;
+
+    int op_offload_min_batch_size;
 };
 
-ggml_metal_device_t ggml_metal_device_init(void);
+typedef struct ggml_metal_event * ggml_metal_event_t;
+
+void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
+void ggml_metal_event_encode_wait  (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
+
+ggml_metal_device_t ggml_metal_device_init(int device);
 void ggml_metal_device_free(ggml_metal_device_t dev);
 
-// return a singleton that is automatically destroyed when the program exits
-ggml_metal_device_t ggml_metal_device_get(void);
+ggml_metal_device_t ggml_metal_device_get(int device);
 
 void * ggml_metal_device_get_obj  (ggml_metal_device_t dev); // id
 void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id
@@ -236,6 +252,10 @@ void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset
 
 void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);
 
+ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev);
+void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev);
+void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev);
+
 void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
 bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
 
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
index 4e5acfbe5fd..3db7f126291 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
@@ -24,9 +24,6 @@
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
 
-// virtual address for GPU memory allocations
-static atomic_uintptr_t g_addr_device = 0x000000400ULL;
-
 #if !GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -349,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
@@ -365,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     [lib->lock lock];
@@ -523,6 +524,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) {
     ggml_metal_library_t library;
 
     struct ggml_metal_device_props props;
+
+    // virtual address for GPU memory allocations
+    atomic_uintptr_t addr_virt;
 };
 
 //
@@ -618,7 +622,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
     free(rsets);
 }
 
-ggml_metal_device_t ggml_metal_device_init(void) {
+ggml_metal_device_t ggml_metal_device_init(int device) {
     ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
 
     assert(dev != NULL);
@@ -632,6 +636,9 @@ ggml_metal_device_t ggml_metal_device_init(void) {
                 GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
             }
 
+            dev->addr_virt = 0x000000400ULL;
+
+            dev->props.device = device;
             dev->props.has_simdgroup_reduction  = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
             dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
@@ -782,11 +789,18 @@ ggml_metal_device_t ggml_metal_device_init(void) {
 
             dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
+            dev->props.op_offload_min_batch_size  = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
+
             dev->props.max_buffer_size            = dev->mtl_device.maxBufferLength;
-            dev->props.max_working_set_size       = dev->mtl_device.recommendedMaxWorkingSetSize;
             dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
+            if (@available(macOS 10.12, iOS 16.0, *)) {
+                dev->props.max_working_set_size   = dev->mtl_device.recommendedMaxWorkingSetSize;
+            } else {
+                dev->props.max_working_set_size   = dev->mtl_device.maxBufferLength;
+            }
 
-            strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1);
+            snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device);
+            snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]);
 
             dev->library = ggml_metal_library_init(dev);
             if (!dev->library) {
@@ -916,6 +930,59 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) {
     atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed);
 }
 
+struct ggml_metal_event {
+    void * obj; // id
+
+    atomic_int value;
+};
+
+void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
+    id event = (id)ev->obj;
+
+    id cmd_buf = (id) cmd_buf_raw;
+
+    [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1];
+}
+
+void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
+    id event = (id)ev->obj;
+
+    id cmd_buf = (id) cmd_buf_raw;
+
+    [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
+}
+
+ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) {
+    id event = [dev->mtl_device newEvent];
+
+    ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event));
+
+    ev->obj = (__bridge void *)event;
+    ev->value = 0;
+
+    return ev;
+}
+
+void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) {
+    id event = ev->obj;
+    [event release];
+
+    free(ev);
+
+    GGML_UNUSED(dev);
+}
+
+void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id event = ev->obj;
+
+        id cmd_buf = [dev->mtl_queue commandBuffer];
+        [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
+        [cmd_buf commit];
+        [cmd_buf waitUntilCompleted];
+    }
+}
+
 void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
     if (@available(macOS 10.12, iOS 16.0, *)) {
         *total = dev->mtl_device.recommendedMaxWorkingSetSize;
@@ -944,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
     }
 
     switch (op->op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_FILL:
+        case GGML_OP_CLAMP:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_LOG:
+            return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
@@ -963,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_EXPM1:
-                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                    return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
                 default:
                     return false;
             }
@@ -991,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_ADD_ID:
-            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
+            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_REPEAT:
-        case GGML_OP_SCALE:
-        case GGML_OP_FILL:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1003,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32;
-        case GGML_OP_CLAMP:
-            return op->src[0]->type == GGML_TYPE_F32;
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_LOG:
-            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_TRI:
@@ -1020,15 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_L2_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
-        case GGML_OP_SOLVE_TRI:
-            return ggml_is_contiguous(op->src[0]) &&
-                ggml_is_contiguous(op->src[1]) &&
-                op->src[0]->type == GGML_TYPE_F32 &&
-                op->src[1]->type == GGML_TYPE_F32 &&
-                op->type == GGML_TYPE_F32;
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_COUNT_EQUAL:
             return has_simdgroup_reduction &&
                 op->src[0]->type == GGML_TYPE_I32 &&
@@ -1048,10 +1107,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    op->src[1]->type == GGML_TYPE_F32 &&
                    op->type == GGML_TYPE_F32 &&
                    (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
-        case GGML_OP_POOL_1D:
-            return false;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+        case GGML_OP_POOL_1D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
@@ -1096,9 +1155,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;
+        case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return has_simdgroup_reduction;
+        case GGML_OP_SET:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
@@ -1177,6 +1238,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                         return false;
                 };
             }
+        case GGML_OP_DIAG:
+            return true;
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return has_simdgroup_reduction;
@@ -1344,8 +1407,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
         res->all_data = ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
     } else {
-        // use virtual address from g_addr_device counter
-        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
+        // use virtual address
+        res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
     res->all_size = size_aligned;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index df243edcbd7..033028d6675 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -1963,14 +1963,50 @@ GGML_TABLE_END()
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
 #define FC_SSM_CONV                    900
+#define FC_SOLVE_TRI                   1000
+#define FC_COUNT_EQUAL                 1100
+#define FC_UNARY                       1200
+#define FC_BIN                         1300
+#define FC_SUM_ROWS                    1400
 
 // op-specific constants
-#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_FLASH_ATTN_EXT_NCPSG 64
 
-#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
+#define OP_UNARY_NUM_SCALE      10
+#define OP_UNARY_NUM_FILL       11
+#define OP_UNARY_NUM_CLAMP      12
+#define OP_UNARY_NUM_SQR        13
+#define OP_UNARY_NUM_SQRT       14
+#define OP_UNARY_NUM_SIN        15
+#define OP_UNARY_NUM_COS        16
+#define OP_UNARY_NUM_LOG        17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH        100
+#define OP_UNARY_NUM_RELU        101
+#define OP_UNARY_NUM_SIGMOID     102
+#define OP_UNARY_NUM_GELU        103
+#define OP_UNARY_NUM_GELU_ERF    104
+#define OP_UNARY_NUM_GELU_QUICK  105
+#define OP_UNARY_NUM_SILU        106
+#define OP_UNARY_NUM_ELU         107
+#define OP_UNARY_NUM_NEG         108
+#define OP_UNARY_NUM_ABS         109
+#define OP_UNARY_NUM_SGN         110
+#define OP_UNARY_NUM_STEP        111
+#define OP_UNARY_NUM_HARDSWISH   112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP         114
+#define OP_UNARY_NUM_SOFTPLUS    115
+#define OP_UNARY_NUM_EXPM1       116
+
+#define OP_SUM_ROWS_NUM_SUM_ROWS 10
+#define OP_SUM_ROWS_NUM_MEAN     11
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -2006,6 +2042,31 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    slope;
+    float    scale;
+    float    bias;
+    float    val;
+    float    min;
+    float    max;
+} ggml_metal_kargs_unary;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -2063,20 +2124,6 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_repeat;
 
-typedef struct {
-    float scale;
-    float bias;
-} ggml_metal_kargs_scale;
-
-typedef struct {
-    float val;
-} ggml_metal_kargs_fill;
-
-typedef struct {
-    float min;
-    float max;
-} ggml_metal_kargs_clamp;
-
 typedef struct {
     int64_t  nk0;
     int64_t  ne00;
@@ -2378,13 +2425,6 @@ typedef struct {
     uint64_t nbf3[3];
 } ggml_metal_kargs_norm;
 
-typedef struct {
-    int32_t  ne00;
-    int32_t  ne00_4;
-    uint64_t nb01;
-    float    eps;
-} ggml_metal_kargs_l2_norm;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -2394,17 +2434,16 @@ typedef struct {
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
-    int32_t  ne10;
-    int32_t  ne11;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
     uint64_t nb0;
     uint64_t nb1;
     uint64_t nb2;
     uint64_t nb3;
-} ggml_metal_kargs_solve_tri;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
 
 typedef struct {
     int64_t  ne00;
@@ -2638,6 +2677,33 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
 typedef struct {
     int32_t  ne00t;
     int32_t  ne00;
@@ -2669,6 +2735,25 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_set_rows;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_diag;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
@@ -2738,10 +2823,6 @@ typedef struct {
     int      max_period;
 } ggml_metal_kargs_timestep_embedding;
 
-typedef struct {
-    float    slope;
-} ggml_metal_kargs_leaky_relu;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -2800,6 +2881,25 @@ typedef struct {
     float    step;
 } ggml_metal_kargs_arange;
 
+typedef struct {
+    int64_t val;
+} ggml_metal_kargs_memset;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+} ggml_metal_kargs_count_equal;
+
 typedef struct {
     int32_t  k0;
     int32_t  k1;
@@ -2814,6 +2914,15 @@ typedef struct {
     int64_t  np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t  k0;
+    int32_t  s0;
+    int32_t  p0;
+    int64_t  IW;
+    int64_t  OW;
+    int64_t  np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
      int64_t ne00;
     uint64_t nb01;
@@ -2899,6 +3008,14 @@ static inline float dot(float x, float y) {
     return x*y;
 }
 
+static inline float sum(float x) {
+    return x;
+}
+
+static inline float sum(float4 x) {
+    return x[0] + x[1] + x[2] + x[3];
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template 
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -3717,751 +3834,428 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-template 
-kernel void kernel_add_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template
+inline T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
+
+template T elu_approx(T x);
+
+template<> inline float elu_approx(float x) {
+    return (x > 0.f) ? x : (exp(x) - 1);
+}
+
+template<> inline float4 elu_approx(float4 x) {
+    float4 res;
+
+    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
+    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
+    res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
+    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
+
+    return res;
+}
+
+constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
+constant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
+
+template 
+kernel void kernel_unary_impl(
+        constant ggml_metal_kargs_unary & args,
         device const char * src0,
-        device const char * src1,
         device       char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+#define FC_OP  FC_unary_op
+#define FC_CNT FC_unary_cnt
+
+    device const T0 * src0_ptr;
+    device       T  * dst_ptr;
+
+    int i0;
+
+    if (FC_CNT) {
+        i0 = tgpig.x;
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        src0_ptr = (device const T0 *) (src0);
+        dst_ptr  = (device       T  *) (dst);
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int k0  = tgpig.x/args.ne01;
+        const int i01 = tgpig.x - k0*args.ne01;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
-    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+        i0 = k0*ntg.x + tpitg.x;
 
-    device const float * src1_ptr[F];
-    for (short j = 0; j < F; ++j) {
-        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
     }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
+    {
+        //threadgroup_barrier(mem_flags::mem_none);
+
+        if (!FC_CNT) {
+            if (i0 >= args.ne0) {
+                return;
+            }
+        }
 
-        float res = src0_ptr[i0];
+        const TC x = (TC) src0_ptr[i0];
 
-#pragma unroll
-        for (short j = 0; j < F; ++j) {
-            res += src1_ptr[j][i10];
+        if (FC_OP == OP_UNARY_NUM_SCALE) {
+            dst_ptr[i0] = (T) (args.scale * x + args.bias);
         }
 
-        dst_ptr[i0] = res;
-    }
-}
+        if (FC_OP == OP_UNARY_NUM_FILL) {
+            dst_ptr[i0] = (T) args.val;
+        }
 
-typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+        if (FC_OP == OP_UNARY_NUM_CLAMP) {
+            dst_ptr[i0] = (T) clamp(x, args.min, args.max);
+        }
 
-template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
-template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
-template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
-template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
-template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
-template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
-template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
-template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+        if (FC_OP == OP_UNARY_NUM_SQR) {
+            dst_ptr[i0] = (T) (x * x);
+        }
 
-kernel void kernel_sub_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+        if (FC_OP == OP_UNARY_NUM_SQRT) {
+            dst_ptr[i0] = (T) sqrt(x);
+        }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        if (FC_OP == OP_UNARY_NUM_SIN) {
+            dst_ptr[i0] = (T) sin(x);
+        }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        if (FC_OP == OP_UNARY_NUM_COS) {
+            dst_ptr[i0] = (T) cos(x);
+        }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
-    }
-}
+        if (FC_OP == OP_UNARY_NUM_LOG) {
+            dst_ptr[i0] = (T) log(x);
+        }
 
-kernel void kernel_mul_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
+            dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
+        }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        if (FC_OP == OP_UNARY_NUM_TANH) {
+            dst_ptr[i0] = (T) precise::tanh(x);
+        }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        if (FC_OP == OP_UNARY_NUM_RELU) {
+            dst_ptr[i0] = (T) fmax(0, x);
+        }
 
-    if (args.ne10 == 1) {
-        const float x = *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
+            dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
         }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+
+        if (FC_OP == OP_UNARY_NUM_GELU) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
         }
-    }
-}
 
-kernel void kernel_div_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
+        }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
+            dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
+        }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        if (FC_OP == OP_UNARY_NUM_SILU) {
+            dst_ptr[i0] = (T) (x / (1 + exp(-x)));
+        }
 
-    if (args.ne10 == 1) {
-        const float x = 1.0f / *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        if (FC_OP == OP_UNARY_NUM_ELU) {
+            dst_ptr[i0] = (T) elu_approx(x);
         }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+
+        if (FC_OP == OP_UNARY_NUM_NEG) {
+            dst_ptr[i0] = (T) -x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ABS) {
+            dst_ptr[i0] = (T) fabs(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SGN) {
+            dst_ptr[i0] = T(x > 0) - T(x < 0);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_STEP) {
+            dst_ptr[i0] = T(x > 0);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
+            dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
+            dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXP) {
+            dst_ptr[i0] = (T) exp(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
+            dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXPM1) {
+            // TODO: precise implementation
+            dst_ptr[i0] = (T) (exp(x) - 1);
         }
     }
+
+#undef FC_OP
+#undef FC_CNT
 }
 
-kernel void kernel_add_id(
-        constant ggml_metal_kargs_add_id & args,
+typedef decltype(kernel_unary_impl) kernel_unary_t;
+
+template [[host_name("kernel_unary_f32_f32")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl;
+
+// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
+constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
+constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
+constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
+
+template 
+kernel void kernel_bin_fuse_impl(
+        constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
-        device const char * src2,
         device       char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i1 = tgpig.x;
-    const int i2 = tgpig.y;
+#define FC_OP FC_bin_op
+#define FC_F  FC_bin_f
+#define FC_RB FC_bin_rb
 
-    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+    if (FC_RB) {
+        // row broadcast
+        const uint i0 = tgpig.x;
+        const uint i1 = i0%args.ne10;
 
-    const size_t nb1 = args.ne0 * sizeof(float);
-    const size_t nb2 = args.ne1 * nb1;
+        device const T0 * src0_row = (device const T0 *) (src0);
+        device       T  * dst_row  = (device       T  *) (dst);
 
-    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
-    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
-    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+        if (FC_F == 1) {
+            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        dst_row[i0] = src0_row[i0] + src1_row[i0];
-    }
-}
-
-template
-kernel void kernel_repeat(
-        constant ggml_metal_kargs_repeat & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i3 = tgpig.z;
-    const int i2 = tgpig.y;
-    const int i1 = tgpig.x;
-
-    const int i03 = i3%args.ne03;
-    const int i02 = i2%args.ne02;
-    const int i01 = i1%args.ne01;
-
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
-
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i00 = i0%args.ne00;
-        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
-    }
-}
+            if (FC_OP == 0) {
+                dst_row[i0] = src0_row[i0] + src1_row[i1];
+            }
 
-typedef decltype(kernel_repeat) kernel_repeat_t;
+            if (FC_OP == 1) {
+                dst_row[i0] = src0_row[i0] - src1_row[i1];
+            }
 
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat;
+            if (FC_OP == 2) {
+                dst_row[i0] = src0_row[i0] * src1_row[i1];
+            }
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-template 
-kernel void kernel_add_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+            if (FC_OP == 3) {
+                dst_row[i0] = src0_row[i0] / src1_row[i1];
+            }
+        } else {
+            T0 res = src0_row[i0];
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+            if (FC_OP == 0) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-    float4 res = src0_row[tpig];
+            if (FC_OP == 1) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res += ((device const float4 *) (src1 + args.o1[j]))[i];
-    }
+            if (FC_OP == 2) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-    dst_row[tpig] = res;
-}
+            if (FC_OP == 3) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+            dst_row[i0] = res;
+        }
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int i01 = tgpig.x;
 
-template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
-template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
-template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
-template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
-template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
-template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
-template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
-template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+        if (i01 >= args.ne01) {
+            return;
+        }
 
-template 
-kernel void kernel_sub_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
+        const int i13 = i03%args.ne13;
+        const int i12 = i02%args.ne12;
+        const int i11 = i01%args.ne11;
 
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+        if (FC_F == 1) {
+            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-    float4 res = src0_row[tpig];
+                if (FC_OP == 0) {
+                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
+                }
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res -= src1_row[j][i];
-    }
+                if (FC_OP == 1) {
+                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
+                }
 
-    dst_row[tpig] = res;
-}
+                if (FC_OP == 2) {
+                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
+                }
 
-typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+                if (FC_OP == 3) {
+                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
+                }
+            }
+        } else {
+            device const T1 * src1_ptr[8];
+            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+            }
 
-template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-template 
-kernel void kernel_mul_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
+                T res = src0_ptr[i0];
 
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+                if (FC_OP == 0) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res += src1_ptr[j][i10];
+                    }
+                }
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+                if (FC_OP == 1) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res -= src1_ptr[j][i10];
+                    }
+                }
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
+                if (FC_OP == 2) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res *= src1_ptr[j][i10];
+                    }
+                }
 
-    float4 res = src0_row[tpig];
+                if (FC_OP == 3) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res /= src1_ptr[j][i10];
+                    }
+                }
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res *= src1_row[j][i];
+                dst_ptr[i0] = res;
+            }
+        }
     }
 
-    dst_row[tpig] = res;
+#undef FC_OP
+#undef FC_F
+#undef FC_RB
 }
 
-typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t;
 
-template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+template [[host_name("kernel_bin_fuse_f32_f32_f32")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
+template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
 
-template 
-kernel void kernel_div_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
         device const char * src0,
         device const char * src1,
+        device const char * src2,
         device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
 
-    float4 res = src0_row[tpig];
+    device       float * dst_row  = (device       float *)((device char *)dst  +  i1*nb1       + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res /= src1_row[j][i];
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
     }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-
-template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
-
-kernel void kernel_scale_f32(
-        constant ggml_metal_kargs_scale & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
 }
 
-kernel void kernel_scale_f32_4(
-        constant ggml_metal_kargs_scale & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_fill_f32(
-        constant ggml_metal_kargs_fill & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_fill_f32_4(
-        constant ggml_metal_kargs_fill & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_clamp_f32(
-        constant ggml_metal_kargs_clamp & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_clamp_f32_4(
-        constant ggml_metal_kargs_clamp & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_relu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_relu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_sigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-kernel void kernel_tanh_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
-
-kernel void kernel_gelu_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
-// ref: https://www.johndcook.com/blog/python_erf/
-constant float p_erf  = 0.3275911f;
-constant float a1_erf = 0.254829592f;
-constant float a2_erf = -0.284496736f;
-constant float a3_erf = 1.421413741f;
-constant float a4_erf = -1.453152027f;
-constant float a5_erf = 1.061405429f;
-
 template
-T erf_approx(T x) {
-    T sign_x = sign(x);
-    x = fabs(x);
-    T t = 1.0f / (1.0f + p_erf * x);
-    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
-    return sign_x * y;
-}
-
-kernel void kernel_gelu_erf_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
-
-kernel void kernel_gelu_erf_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
-
-kernel void kernel_silu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_elu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
-}
-
-kernel void kernel_elu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
-    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
-    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
-    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
-}
-
-kernel void kernel_sqr_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqr_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sqrt_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_sin_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_cos_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_log_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_log_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_neg_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_neg_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_abs_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_abs_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_step_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_step_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_hardswish_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardswish_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_exp_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
+kernel void kernel_repeat(
+        constant ggml_metal_kargs_repeat & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
 
-kernel void kernel_exp_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
+    const int i03 = i3%args.ne03;
+    const int i02 = i2%args.ne02;
+    const int i01 = i1%args.ne01;
 
-kernel void kernel_softplus_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
 
-kernel void kernel_softplus_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i00 = i0%args.ne00;
+        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+    }
 }
 
-kernel void kernel_expm1_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
+typedef decltype(kernel_repeat) kernel_repeat_t;
 
-kernel void kernel_expm1_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat;
 
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
@@ -4612,6 +4406,7 @@ kernel void kernel_op_sum_f32(
         return;
     }
 
+    // TODO: turn nsg into a function constant instead of computing it per launch
     const uint nsg = (ntg.x + 31) / 32;
 
     float sumf = 0;
@@ -4645,33 +4440,35 @@ kernel void kernel_op_sum_f32(
     }
 }
 
-template 
-kernel void kernel_sum_rows(
+constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
+
+template 
+kernel void kernel_sum_rows_impl(
         constant ggml_metal_kargs_sum_rows & args,
-        device const float * src0,
-        device       float * dst,
-        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        device const char * src0,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    int64_t i3 = tgpig.z;
-    int64_t i2 = tgpig.y;
-    int64_t i1 = tgpig.x;
+#define FC_OP  FC_sum_rows_op
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
-        return;
-    }
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
 
     if (sgitg == 0) {
-        shmem_f32[tiisg] = 0.0f;
+        shmem_t[tiisg] = 0.0f;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float sumf = 0;
+    T0 sumf = T0(0.0f);
 
     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
         sumf += src_row[i0];
@@ -4682,23 +4479,33 @@ kernel void kernel_sum_rows(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     if (tiisg == 0) {
-        shmem_f32[sgitg] = sumf;
+        shmem_t[sgitg] = sumf;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sumf = shmem_f32[tiisg];
+    sumf = shmem_t[tiisg];
     sumf = simd_sum(sumf);
 
     if (tpitg.x == 0) {
-        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+            if (is_same::value) {
+                dst_row[0] = sum(sumf) / (4*args.ne00);
+            } else {
+                dst_row[0] = sum(sumf) / args.ne00;
+            }
+        } else {
+            dst_row[0] = sum(sumf);
+        }
     }
+
+#undef FC_OP
 }
 
-typedef decltype(kernel_sum_rows) kernel_sum_rows_t;
+typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t;
 
-template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows;
-template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows;
+template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl;
+template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl;
 
 template
 kernel void kernel_cumsum_blk(
@@ -5558,6 +5365,80 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
+constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
+constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
+
+kernel void kernel_solve_tri_f32(
+        constant ggml_metal_kargs_solve_tri & args,
+        device   const char * src0,
+        device   const char * src1,
+        device         char * dst,
+        threadgroup    char * shmem [[threadgroup(0)]],
+        ushort3 tgpig[[threadgroup_position_in_grid]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const short NSG = FC_solve_tri_nsg;
+    const short N   = FC_solve_tri_n;
+    const short K   = FC_solve_tri_k;
+    const short NP  = PAD2(N, NW);
+
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x*NSG + sgitg;
+
+    threadgroup float * sh0 = (threadgroup float *) shmem;
+
+    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
+    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
+    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;
+
+    for (short rr = 0; rr < N; rr += NSG) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        {
+            threadgroup float * sh0_cur = sh0 + sgitg*NP;
+
+            for (short t = 0; t*NW < N; ++t) {
+                const short idx = t*NW + tiisg;
+                sh0_cur[idx] = src0_ptr[idx];
+            }
+
+            src0_ptr += NSG*N;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (i01 >= args.ne10) {
+            continue;
+        }
+
+        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
+            const short r = rr + ir;
+
+            threadgroup float * sh0_cur = sh0 + ir*NP;
+
+            float sum = 0.0f;
+
+            for (short t = 0; t*NW < r; ++t) {
+                const short idx = t*NW + tiisg;
+                sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
+            }
+
+            sum = simd_sum(sum);
+
+            if (tiisg == 0) {
+                const float diag = sh0_cur[r];
+
+                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
+            }
+        }
+    }
+}
+
 kernel void kernel_argmax_f32(
         constant ggml_metal_kargs_argmax & args,
         device   const char * src0,
@@ -5791,26 +5672,32 @@ template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_f
 template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 
-kernel void kernel_l2_norm_f32(
+template 
+kernel void kernel_l2_norm_impl(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -5828,71 +5715,15 @@ kernel void kernel_l2_norm_f32(
 
     const float scale = 1.0f/sqrt(max(sumf, args.eps));
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;
     }
 }
 
-kernel void kernel_solve_tri_f32(
-        constant ggml_metal_kargs_solve_tri & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
-    const uint64_t ncols = (uint64_t) args.ne10;
-    const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
-    const uint64_t nr = n_batches * ncols;
-
-    const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
-    if (gid >= nr) {
-        return;
-    }
-
-    const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
-    const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
-    const uint64_t i02 = rem / ncols;
-    const uint64_t i01 = rem - i02 * ncols;
-
-    const uint64_t sa0 = args.nb00 / sizeof(float);
-    const uint64_t sa1 = args.nb01 / sizeof(float);
-    const uint64_t sa2 = args.nb02 / sizeof(float);
-    const uint64_t sa3 = args.nb03 / sizeof(float);
-
-    const uint64_t sb0 = args.nb10 / sizeof(float);
-    const uint64_t sb1 = args.nb11 / sizeof(float);
-    const uint64_t sb2 = args.nb12 / sizeof(float);
-    const uint64_t sb3 = args.nb13 / sizeof(float);
-
-    const uint64_t sx0 = args.nb0 / sizeof(float);
-    const uint64_t sx1 = args.nb1 / sizeof(float);
-    const uint64_t sx2 = args.nb2 / sizeof(float);
-    const uint64_t sx3 = args.nb3 / sizeof(float);
-
-    device const float * A = (device const float *) src0;
-    device const float * B = (device const float *) src1;
-    device       float * X = (device       float *) dst;
+typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t;
 
-    const uint64_t A_base = i02 * sa2 + i03 * sa3;
-    const uint64_t B_base = i02 * sb2 + i03 * sb3;
-    const uint64_t X_base = i02 * sx2 + i03 * sx3;
-
-    const uint64_t n = (uint64_t) args.ne11;
-
-    for (uint64_t i00 = 0; i00 < n; ++i00) {
-        float sum = 0.0f;
-        for (uint64_t t = 0; t < i00; ++t) {
-            sum += A[A_base + i00 * sa1 + t * sa0] *
-                X[X_base + t * sx1 + i01 * sx0];
-        }
-
-        const float diag = A[A_base + i00 * sa1 + i00 * sa0];
-        X[X_base + i00 * sx1 + i01 * sx0] =
-            (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
-    }
-}
+template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl;
 
 kernel void kernel_group_norm_f32(
         constant ggml_metal_kargs_group_norm & args,
@@ -8210,24 +8041,6 @@ template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge
 template [[host_name("kernel_argsort_merge_i32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_i32_i32;
 template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32;
 
-kernel void kernel_leaky_relu_f32(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x > 0.0f ? x : x * args.slope;
-}
-
-kernel void kernel_leaky_relu_f32_4(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
-}
-
 constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
 
 constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -8304,6 +8117,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
 // scan the blocks of the mask that are not masked
 // 0 -     masked (i.e. full of -INF, skip)
 // 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - not masked and all elements are zero (mask add can be skipped)
 kernel void kernel_flash_attn_ext_blk(
         constant ggml_metal_kargs_flash_attn_ext_blk & args,
         device const char * mask,
@@ -8325,27 +8139,29 @@ kernel void kernel_flash_attn_ext_blk(
 
     device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
 
-    // fast route
-    if (res == 0) {
-        if (simd_max(*mask_src) > -MAXHALF/2) {
-            res = 1;
-        }
-    }
-
     // detailed check of the elements of the block
     if ((C > NW || Q > 1) && res == 0) {
-        half m = -MAXHALF;
+        half mmin =  MAXHALF;
+        half mmax = -MAXHALF;
 
         FOR_UNROLL (short j = 0; j < Q; ++j) {
             FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
-                m = max(m, mask_src[ii*NW]);
+                mmin = min(mmin, mask_src[ii*NW]);
+                mmax = max(mmax, mask_src[ii*NW]);
             }
 
             mask_src += args.nb31/2;
         }
 
-        if (simd_max(m) > -MAXHALF/2) {
-            res = 1;
+        mmin = simd_min(mmin);
+        mmax = simd_max(mmax);
+
+        if (mmax > -MAXHALF) {
+            if (mmin == 0.0 && mmax == 0.0) {
+                res = 2;
+            } else {
+                res = 1;
+            }
         }
     }
 
@@ -8587,9 +8403,13 @@ void kernel_flash_attn_ext_impl(
                 ic = 0;
             }
 
+            char blk_cur = 1;
+
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
-                if (blk[ic0] == 0) {
+                blk_cur = blk[ic0];
+
+                if (blk_cur == 0) {
                     FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                         pm2[jj] += NW;
                     }
@@ -8597,16 +8417,22 @@ void kernel_flash_attn_ext_impl(
                     continue;
                 }
 
-                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
-                    const short j = jj*NSG + sgitg;
+                if (blk_cur == 1) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
+
+                        if (FC_flash_attn_ext_bc_mask) {
+                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                        } else {
+                            sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        }
 
-                    if (FC_flash_attn_ext_bc_mask) {
-                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
-                    } else {
-                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        pm2[jj] += NW;
+                    }
+                } else if (blk_cur == 2) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        pm2[jj] += NW;
                     }
-
-                    pm2[jj] += NW;
                 }
 
 #if 0
@@ -8648,9 +8474,7 @@ void kernel_flash_attn_ext_impl(
 
                 constexpr short NC = (C/8)/NSG;
 
-                // note: do not unroll for large heads
-                #pragma unroll (DK <= 64 ? NC : 1)
-                for (short cc = 0; cc < NC; ++cc) {
+                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f);
 
                     if (DK % 16 != 0) {
@@ -8671,7 +8495,9 @@ void kernel_flash_attn_ext_impl(
                         k8x8_t mk[2];
                         q8x8_t mq[2];
 
-                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+                        // note: too much unroll can tank the performance for large heads
+                        #pragma unroll (MIN(DK8/2, 4*NSG))
+                        for (short i = 0; i < DK8/2; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -8771,10 +8597,12 @@ void kernel_flash_attn_ext_impl(
                 }
 
                 // mqk = mqk + slope*mask
-                if (FC_flash_attn_ext_has_bias) {
-                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
-                } else {
-                    s2 += s2_t(sm2[j*SH + tiisg]);
+                if (blk_cur != 2) {
+                    if (FC_flash_attn_ext_has_bias) {
+                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                    } else {
+                        s2 += s2_t(sm2[j*SH + tiisg]);
+                    }
                 }
 
                 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
@@ -8845,7 +8673,9 @@ void kernel_flash_attn_ext_impl(
                                 pv  += 8*NS20;
                             }
                         } else {
-                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+                            constexpr short NC = (C/8)/2;
+
+                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                                 s8x8_t vs[2];
 
                                 simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -9025,7 +8855,7 @@ template<
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
     short DK,         // K head size
     short DV,         // V head size
-    short Q  = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
     short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         constant ggml_metal_kargs_flash_attn_ext & args,
@@ -9235,11 +9065,10 @@ template<
     void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
     short DK,       // K head size
     short DV,       // V head size
-    short NE,       // head elements per thread
-    short Q,        // queries per threadgroup
-    short C,        // cache items per threadgroup
-    short NSG>      // number of simd groups
-void kernel_flash_attn_ext_vec_impl(
+    short NE = 4,   // head elements per thread
+    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
         device const char * k,
@@ -9256,6 +9085,7 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DV % 32 == 0, "DV must be divisible by 32");
 
 #define NWG  (FC_flash_attn_ext_vec_nwg)
+#define NSG  (FC_flash_attn_ext_vec_nsg)
 
 #define NS10 (FC_flash_attn_ext_vec_ns10)
 #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -9282,14 +9112,14 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
     static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
 
-    const short T = PK + NSG*SH; // shared memory size per query in (half)
+  //const short T = PK + NSG*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*PK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*PK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*PK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*PK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
-    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                      0*PK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t
+    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
+    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results
 
     // store the result for all queries in shared memory (the O matrix from the paper)
     so4 += tiisg;
@@ -9307,11 +9137,13 @@ void kernel_flash_attn_ext_vec_impl(
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q);
 
-    for (short i = tiisg; i < PK4; i += NW) {
-        if (iq1 < args.ne01 && i < DK4) {
-            sq4[i] = (q4_t) q4[i];
-        } else {
-            sq4[i] = (q4_t) 0.0f;
+    if (iq1 < args.ne01) {
+        for (short i = tiisg; i < PK4; i += NW) {
+            if (i < DK4) {
+                sq4[i] = (q4_t) q4[i];
+            } else {
+                sq4[i] = (q4_t) 0.0f;
+            }
         }
     }
 
@@ -9389,7 +9221,7 @@ void kernel_flash_attn_ext_vec_impl(
             }
 
             // skip -INF blocks
-            if (simd_max(sm[tiisg]) == -INFINITY) {
+            if (simd_max(sm[tiisg]) <= -MAXHALF) {
                 continue;
             }
 
@@ -9663,57 +9495,11 @@ void kernel_flash_attn_ext_vec_impl(
     }
 
 #undef NWG
+#undef NSG
 #undef NS10
 #undef NS20
 }
 
-template<
-    typename q4_t,  // query types in shared memory
-    typename k4_t,  // key types in shared memory
-    typename v4_t,  // value types in shared memory
-    typename qk_t,  // Q*K types
-    typename s_t,   // soft-max types
-    typename s4_t,
-    typename o4_t,  // attention accumulation types
-    typename kd4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
-    typename vd4_t, // value type in device memory
-    short nl_v,
-    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
-    short DK,       // K head size
-    short DV,       // V head size
-    short NE = 4,   // head elements per thread
-    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG,  // queries per threadgroup
-    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        constant ggml_metal_kargs_flash_attn_ext_vec & args,
-        device const char * q,
-        device const char * k,
-        device const char * v,
-        device const char * mask,
-        device const char * sinks,
-        device const char * pad,
-        device       char * dst,
-        threadgroup  half * shmem_f16 [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
-    switch (FC_flash_attn_ext_vec_nsg) {
-      // note: disabled cases to reduce library load time
-        case 1:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 2:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 4:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 8:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-    }
-#undef FWD_TMPL
-#undef FWD_ARGS
-}
-
 // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
@@ -11876,6 +11662,26 @@ kernel void kernel_set_rows_f(
     }
 }
 
+kernel void kernel_diag_f32(
+        constant ggml_metal_kargs_diag & args,
+        device   const char * src0,
+        device         char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig.z;
+    const int32_t i2 = tgpig.y;
+    const int32_t i1 = tgpig.x;
+
+    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);
+    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2  + i3*args.nb3);
+
+    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
+
 constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
 constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
 
@@ -11894,7 +11700,9 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -12017,8 +11825,8 @@ kernel void kernel_mul_mm(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -12245,6 +12053,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
 template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
@@ -12267,7 +12076,9 @@ kernel void kernel_mul_mm_id(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -12402,8 +12213,8 @@ kernel void kernel_mul_mm_id(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -12656,9 +12467,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_m
 
 template [[host_name("kernel_mul_mm_f32_f16")]]     kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_f16_f16")]]     kernel mul_mm_t kernel_mul_mm;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mm_bf16_f16")]]    kernel mul_mm_t kernel_mul_mm;
-#endif
 template [[host_name("kernel_mul_mm_q4_0_f16")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q4_1_f16")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q5_0_f16")]]    kernel mul_mm_t kernel_mul_mm;
@@ -12714,9 +12522,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mul_mm_id kernel_m
 
 template [[host_name("kernel_mul_mm_id_f32_f16")]]     kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_f16_f16")]]     kernel mul_mm_id kernel_mul_mm_id;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
-#endif
 template [[host_name("kernel_mul_mm_id_q4_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
@@ -12972,6 +12777,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+
+kernel void kernel_pool_1d_max_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW){
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
 kernel void kernel_opt_step_adamw_f32(
         constant    ggml_metal_kargs_opt_step_adamw & args,
         device       float * x,
@@ -13019,3 +12892,75 @@ kernel void kernel_opt_step_sgd_f32(
 
     x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
 }
+
+template
+kernel void kernel_memset(
+        constant ggml_metal_kargs_memset & args,
+        device T * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+typedef decltype(kernel_memset) kernel_memset_t;
+
+template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset;
+
+constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
+
+template
+kernel void kernel_count_equal(
+        constant ggml_metal_kargs_count_equal & args,
+        device   const char * src0,
+        device   const char * src1,
+        device   atomic_int * dst,
+        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const short NSG = FC_count_equal_nsg;
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    int sum = 0;
+
+    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
+    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
+
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        const T v0 = *(device const T *)(base0 + i0*args.nb00);
+        const T v1 = *(device const T *)(base1 + i0*args.nb10);
+        sum += (v0 == v1);
+    }
+
+    sum = simd_sum(sum);
+
+    if (tiisg == 0) {
+        shmem_i32[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (sgitg == 0) {
+        float v = 0.0f;
+        if (tpitg.x < NSG) {
+            v = shmem_i32[tpitg.x];
+        }
+
+        float total = simd_sum(v);
+        if (tpitg.x == 0) {
+            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
+        }
+    }
+}
+
+typedef decltype(kernel_count_equal) kernel_count_equal_t;
+
+template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
index cfdea9c0721..383e0d6e93b 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -78,14 +78,50 @@
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
 #define FC_SSM_CONV                    900
+#define FC_SOLVE_TRI                   1000
+#define FC_COUNT_EQUAL                 1100
+#define FC_UNARY                       1200
+#define FC_BIN                         1300
+#define FC_SUM_ROWS                    1400
 
 // op-specific constants
-#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_FLASH_ATTN_EXT_NCPSG 64
 
-#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
+#define OP_UNARY_NUM_SCALE      10
+#define OP_UNARY_NUM_FILL       11
+#define OP_UNARY_NUM_CLAMP      12
+#define OP_UNARY_NUM_SQR        13
+#define OP_UNARY_NUM_SQRT       14
+#define OP_UNARY_NUM_SIN        15
+#define OP_UNARY_NUM_COS        16
+#define OP_UNARY_NUM_LOG        17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH        100
+#define OP_UNARY_NUM_RELU        101
+#define OP_UNARY_NUM_SIGMOID     102
+#define OP_UNARY_NUM_GELU        103
+#define OP_UNARY_NUM_GELU_ERF    104
+#define OP_UNARY_NUM_GELU_QUICK  105
+#define OP_UNARY_NUM_SILU        106
+#define OP_UNARY_NUM_ELU         107
+#define OP_UNARY_NUM_NEG         108
+#define OP_UNARY_NUM_ABS         109
+#define OP_UNARY_NUM_SGN         110
+#define OP_UNARY_NUM_STEP        111
+#define OP_UNARY_NUM_HARDSWISH   112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP         114
+#define OP_UNARY_NUM_SOFTPLUS    115
+#define OP_UNARY_NUM_EXPM1       116
+
+#define OP_SUM_ROWS_NUM_SUM_ROWS 10
+#define OP_SUM_ROWS_NUM_MEAN     11
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -121,6 +157,31 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    slope;
+    float    scale;
+    float    bias;
+    float    val;
+    float    min;
+    float    max;
+} ggml_metal_kargs_unary;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -178,20 +239,6 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_repeat;
 
-typedef struct {
-    float scale;
-    float bias;
-} ggml_metal_kargs_scale;
-
-typedef struct {
-    float val;
-} ggml_metal_kargs_fill;
-
-typedef struct {
-    float min;
-    float max;
-} ggml_metal_kargs_clamp;
-
 typedef struct {
     int64_t  nk0;
     int64_t  ne00;
@@ -493,13 +540,6 @@ typedef struct {
     uint64_t nbf3[3];
 } ggml_metal_kargs_norm;
 
-typedef struct {
-    int32_t  ne00;
-    int32_t  ne00_4;
-    uint64_t nb01;
-    float    eps;
-} ggml_metal_kargs_l2_norm;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -509,17 +549,16 @@ typedef struct {
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
-    int32_t  ne10;
-    int32_t  ne11;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
     uint64_t nb0;
     uint64_t nb1;
     uint64_t nb2;
     uint64_t nb3;
-} ggml_metal_kargs_solve_tri;
+    float    eps;
+} ggml_metal_kargs_l2_norm;
 
 typedef struct {
     int64_t  ne00;
@@ -753,6 +792,33 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
 typedef struct {
     int32_t  ne00t;
     int32_t  ne00;
@@ -784,6 +850,25 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_set_rows;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_diag;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
@@ -853,10 +938,6 @@ typedef struct {
     int      max_period;
 } ggml_metal_kargs_timestep_embedding;
 
-typedef struct {
-    float    slope;
-} ggml_metal_kargs_leaky_relu;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -915,6 +996,25 @@ typedef struct {
     float    step;
 } ggml_metal_kargs_arange;
 
+typedef struct {
+    int64_t val;
+} ggml_metal_kargs_memset;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+} ggml_metal_kargs_count_equal;
+
 typedef struct {
     int32_t  k0;
     int32_t  k1;
@@ -929,6 +1029,15 @@ typedef struct {
     int64_t  np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t  k0;
+    int32_t  s0;
+    int32_t  p0;
+    int64_t  IW;
+    int64_t  OW;
+    int64_t  np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
      int64_t ne00;
     uint64_t nb01;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
index ac5ad53db05..771cb387622 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_ABORT("unsupported op");
     }
 
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return 1;
+    }
+
     int n_fuse = 1;
 
     // check if the current node can run concurrently with other nodes before it
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
                 n_fuse = ggml_metal_op_acc(ctx, idx);
             } break;
         case GGML_OP_SCALE:
-            {
-                n_fuse = ggml_metal_op_scale(ctx, idx);
-            } break;
         case GGML_OP_FILL:
-            {
-                n_fuse = ggml_metal_op_fill(ctx, idx);
-            } break;
         case GGML_OP_CLAMP:
-            {
-                n_fuse = ggml_metal_op_clamp(ctx, idx);
-            } break;
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
@@ -337,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
             } break;
+        case GGML_OP_SOLVE_TRI:
+            {
+                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -353,13 +353,13 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_set_rows(ctx, idx);
             } break;
-        case GGML_OP_L2_NORM:
+        case GGML_OP_DIAG:
             {
-                n_fuse = ggml_metal_op_l2_norm(ctx, idx);
+                n_fuse = ggml_metal_op_diag(ctx, idx);
             } break;
-        case GGML_OP_SOLVE_TRI:
+        case GGML_OP_L2_NORM:
             {
-                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+                n_fuse = ggml_metal_op_l2_norm(ctx, idx);
             } break;
         case GGML_OP_GROUP_NORM:
             {
@@ -418,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_top_k(ctx, idx);
             } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
-            } break;
         case GGML_OP_TRI:
             {
                 n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -430,12 +426,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
             } break;
+        case GGML_OP_SET:
+            {
+                n_fuse = ggml_metal_op_set(ctx, idx);
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -452,7 +456,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
             } break;
-       default:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_fuse = ggml_metal_op_count_equal(ctx, idx);
+            } break;
+        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
                 GGML_ABORT("fatal error");
@@ -612,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->type         == GGML_TYPE_F32);
 
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
     const size_t pnb1 = ((const int32_t *) op->op_params)[0];
     const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -663,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     }
 
     ggml_metal_kargs_bin args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
         /*.nb00 =*/ nb00,
         /*.nb01 =*/ pnb1,
         /*.nb02 =*/ pnb2,
@@ -679,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.nb11 =*/ nb11,
         /*.nb12 =*/ nb12,
         /*.nb13 =*/ nb13,
-        /*.ne0  =*/ ne0,
-        /*.ne1  =*/ ne1,
-        /*.ne2  =*/ ne2,
-        /*.ne3  =*/ ne3,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
         /*.nb0  =*/ nb0,
         /*.nb1  =*/ pnb1,
         /*.nb2  =*/ pnb2,
@@ -691,7 +699,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.o1   =*/ { 0 },
     };
 
-    auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
+    auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -699,53 +707,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 
-    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+    const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-    float scale;
-    float bias;
-    memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&bias,  ((const int32_t *) op->op_params) + 1, sizeof(float));
-
-    ggml_metal_kargs_scale args = {
-        /*.scale =*/ scale,
-        /*.bias  =*/ bias,
-    };
-
-    int64_t n = ggml_nelements(op);
+    int nth = 1;
 
-    if (n % 4 == 0) {
-        n /= 4;
+    while (2*nth < args.ne0 && nth < nth_max) {
+        nth *= 2;
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
 
     return 1;
 }
 
-int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
+int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
     ggml_metal_library_t lib = ctx->lib;
@@ -756,94 +731,80 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    const float val = ggml_get_op_params_f32(op, 0);
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
-    ggml_metal_kargs_fill args = {
-        /*.val =*/ val
-    };
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
 
-    int64_t n = ggml_nelements(op);
+    ggml_metal_kargs_unary args = {
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.slope =*/ 0.0,
+        /*.scale =*/ 0.0,
+        /*.bias  =*/ 0.0,
+        /*.val   =*/ 0.0,
+        /*.min   =*/ 0.0,
+        /*.max   =*/ 0.0,
+    };
 
-    if (n % 4 == 0) {
-        n /= 4;
+    if (op->op == GGML_OP_LEAKY_RELU) {
+        args.slope = ggml_get_op_params_f32(op, 0);
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float min;
-    float max;
-    memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
-
-    ggml_metal_kargs_clamp args = {
-        /*.min =*/ min,
-        /*.max =*/ max,
-    };
+    if (op->op == GGML_OP_SCALE) {
+        args.scale = ggml_get_op_params_f32(op, 0);
+        args.bias  = ggml_get_op_params_f32(op, 1);
+    }
 
-    int64_t n = ggml_nelements(op);
+    if (op->op == GGML_OP_FILL) {
+        args.val = ggml_get_op_params_f32(op, 0);
+    }
 
-    if (n % 4 == 0) {
-        n /= 4;
+    if (op->op == GGML_OP_CLAMP) {
+        args.min = ggml_get_op_params_f32(op, 0);
+        args.max = ggml_get_op_params_f32(op, 1);
     }
 
     auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
-int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
+    if (pipeline.cnt) {
+        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
 
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
+        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    } else {
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+        const int nth = MIN(args.ne00, nth_max);
 
-    int64_t n = ggml_nelements(op);
+        const int nk0 = (args.ne00 + nth - 1)/nth;
 
-    if (n % 4 == 0) {
-        n /= 4;
+        ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
     }
 
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         1);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
     return 1;
 }
 
@@ -953,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
@@ -974,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 
     auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
     int nth = 32; // SIMD width
 
-    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00);
+    nth = std::min(nth, (int) args.ne00);
 
     const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
@@ -1247,6 +1218,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(int32_t,  ne, op, ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    ggml_metal_kargs_diag args = {
+        /*.ne00 =*/ne00,
+        /*.ne01 =*/ne01,
+        /*.ne02 =*/ne02,
+        /*.ne03 =*/ne03,
+        /*.nb00 =*/nb00,
+        /*.nb01 =*/nb01,
+        /*.nb02 =*/nb02,
+        /*.nb03 =*/nb03,
+        /*.ne0  =*/ne0,
+        /*.ne1  =*/ne1,
+        /*.ne2  =*/ne2,
+        /*.ne3  =*/ne3,
+        /*.nb0  =*/nb0,
+        /*.nb1  =*/nb1,
+        /*.nb2  =*/nb2,
+        /*.nb3  =*/nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -1514,37 +1527,222 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     ggml_metal_library_t lib = ctx->lib;
     ggml_metal_encoder_t enc = ctx->enc;
 
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
+    const int64_t T = op->src[0]->ne[2];
+    const int64_t C = op->ne[0];
+    const int64_t H = op->src[0]->ne[1];
+
+    auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
+
+    int ida = 0;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
+    if (op->op == GGML_OP_RWKV_WKV7) {
+        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
+    }
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++);
+    ggml_metal_encoder_set_bytes   (enc, (void *) &B, sizeof(B), ida++);
+    ggml_metal_encoder_set_bytes   (enc, (void *) &T, sizeof(T), ida++);
+    ggml_metal_encoder_set_bytes   (enc, (void *) &C, sizeof(C), ida++);
+    ggml_metal_encoder_set_bytes   (enc, (void *) &H, sizeof(H), ida++);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
+
+    return 1;
+}
+
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_kargs_solve_tri args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
+
+    return 1;
+}
+
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+    const size_t offs = ((const int32_t *) op->op_params)[3];
+
+    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+    if (!inplace) {
+        // run a separate kernel to cpy src->dst
+        // not sure how to avoid this
+        // TODO: make a simpler cpy_bytes kernel
+
+        //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+        ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.nb0  =*/ nb0,
+            /*.nb1  =*/ nb1,
+            /*.nb2  =*/ nb2,
+            /*.nb3  =*/ nb3,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+        ggml_metal_op_concurrency_reset(ctx);
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
+    int64_t nk0 = ne10;
+    if (ggml_is_quantized(op->src[1]->type)) {
+        nk0 = ne10/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne10/ggml_blck_size(op->type);
+    }
+
+    int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    // when rows are small, we can batch them together in a single threadgroup
+    int nrptg = 1;
+
+    // TODO: relax this constraint in the future
+    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
+
+            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+                nrptg--;
+            }
+        }
+    }
+
+    nth = std::min(nth, nk0);
 
-    const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
-    const int64_t T = op->src[0]->ne[2];
-    const int64_t C = op->ne[0];
-    const int64_t H = op->src[0]->ne[1];
+    ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
+        /*.nb00 =*/ nb10,
+        /*.nb01 =*/ nb11,
+        /*.nb02 =*/ nb12,
+        /*.nb03 =*/ nb13,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
+        /*.nb0  =*/ ggml_element_size(op),
+        /*.nb1  =*/ pnb1,
+        /*.nb2  =*/ pnb2,
+        /*.nb3  =*/ pnb3,
+    };
 
-    auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
 
-    int ida = 0;
+    bid_dst.offs += offs;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
-    if (op->op == GGML_OP_RWKV_WKV7) {
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
-    }
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++);
-    ggml_metal_encoder_set_bytes   (enc, (void *) &B, sizeof(B), ida++);
-    ggml_metal_encoder_set_bytes   (enc, (void *) &T, sizeof(T), ida++);
-    ggml_metal_encoder_set_bytes   (enc, (void *) &C, sizeof(C), ida++);
-    ggml_metal_encoder_set_bytes   (enc, (void *) &H, sizeof(H), ida++);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
 
     return 1;
 }
@@ -1622,6 +1820,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1];
+    const int32_t s0 = opts[2];
+    const int32_t p0 = opts[3];
+
+    const int64_t IW = op->src[0]->ne[0];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = ggml_nelements(op);
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */  k0,
+        /* .s0 = */  s0,
+        /* .p0 = */  p0,
+        /* .IW = */  IW,
+        /* .OW = */  OW,
+        /* .np = */  np
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d),  0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -2182,7 +2428,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
 
     const bool has_mask = op->src[3] != nullptr;
 
-    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    // note: the non-vec kernel requires more extra memory, so always reserve for it
+    GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
+
+    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    if (false) {
         // note: always reserve the padding space to avoid graph reallocations
         //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
         const bool has_kvpad = true;
@@ -2236,7 +2486,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
     //    return res;
     //}
 
-    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
 
     const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
@@ -2352,7 +2602,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
@@ -2519,9 +2769,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int nkpsg = 1*ncpsg;
+        const int nhptg = 1;                           // heads per threadgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 1  == 0);
@@ -2573,6 +2823,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_op_concurrency_reset(ctx);
         }
 
+        // note: for simplicity assume the K is larger or equal than V
+        GGML_ASSERT(ne10 >= ne20);
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2580,28 +2833,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         // ne20*(nsg)
         // each simdgroup has a full f32 head vector in shared mem to accumulate results
         //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
-
-        int64_t nsgmax = 2;
-        while (true) {
-            const size_t smem = FATTN_SMEM(nsgmax);
-            // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
-            if (smem > props_dev->max_theadgroup_memory_size/2) {
-                break;
-            }
-            nsgmax *= 2;
-        }
-        nsgmax /= 2;
-
-        // simdgroups per threadgroup (a.k.a. warps)
-        //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
-        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
+#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
 
         int64_t nsg = 1;
-        while (nsg <= nsgt) {
-            nsg *= 2;
-        }
-        nsg /= 2;
 
         // workgroups
         // each workgroup handles nsg*nkpsg cache values
@@ -2614,7 +2848,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         } else {
             nwg = 32;
             nsg = 1;
-            while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+            while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
                 nsg *= 2;
             }
         }
@@ -2680,7 +2914,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
             assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
@@ -2693,7 +2927,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
 
             // sync the 2 kernels
             ggml_metal_op_concurrency_reset(ctx);
@@ -2745,8 +2979,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
-    bool bcast_row = false;
-
     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
     ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
@@ -2840,18 +3072,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
 
     struct ggml_metal_pipeline_with_params pipeline;
 
-    if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
-        // src1 is a row
-        GGML_ASSERT(ne11 == 1);
-
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
-
-        bcast_row = true;
-    } else {
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
-    }
+    pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
 
     if (n_fuse > 1) {
         bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -2865,20 +3086,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
         }
     }
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne10 = ne10/4;
+        args.ne0  = ne0/4;
+    }
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
     ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
     ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
 
-    if (bcast_row) {
-        const int64_t n = ggml_nelements(op)/4;
+    if (pipeline.cnt) {
+        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
 
         ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
     } else {
-        int nth = 32;
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-        while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        int nth = 1;
+
+        while (2*nth < args.ne0 && nth < nth_max) {
             nth *= 2;
         }
 
@@ -2899,98 +3128,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
 
-    int nth = 32; // SIMD width
-
     ggml_metal_kargs_l2_norm args = {
-        /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb01   =*/ nb01,
-        /*.eps    =*/ eps,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.eps   =*/ eps,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
 
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
 
     const size_t smem = pipeline.smem;
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    ggml_metal_kargs_solve_tri args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
-        /*.nb00 =*/ nb00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.nb03 =*/ nb03,
-        /*.ne10 =*/ ne10,
-        /*.ne11 =*/ ne11,
-        /*.nb10 =*/ nb10,
-        /*.nb11 =*/ nb11,
-        /*.nb12 =*/ nb12,
-        /*.nb13 =*/ nb13,
-        /*.nb0  =*/ nb0,
-        /*.nb1  =*/ nb1,
-        /*.nb2  =*/ nb2,
-        /*.nb3  =*/ nb3,
-    };
-
-    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
-
-    const int64_t ncols = ne10;
-    const int64_t n_batches = (int64_t)ne02 * ne03;
-    const int64_t nr = n_batches * ncols;
-
-    int nth = 64;
-    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    if (nth < 1) {
-        nth = 1;
-    }
-
-    const int64_t n_tg = (nr + nth - 1) / nth;
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 
     return 1;
 }
@@ -3998,42 +4188,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
-int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float slope;
-    memcpy(&slope, op->op_params, sizeof(float));
-
-    ggml_metal_kargs_leaky_relu args = {
-        /*.slope =*/ slope
-    };
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
 int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -4154,3 +4308,64 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
 
     return 1;
 }
+
+int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+
+    {
+        ggml_metal_kargs_memset args = { /*.val =*/ 0 };
+
+        auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    }
+
+    ggml_metal_op_concurrency_reset(ctx);
+
+    {
+        ggml_metal_kargs_count_equal args = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.nb10 =*/ nb10,
+            /*.nb11 =*/ nb11,
+            /*.nb12 =*/ nb12,
+            /*.nb13 =*/ nb13,
+        };
+
+        auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
+
+        const size_t smem = pipeline.smem;
+
+        const int nth = 32*pipeline.nsg;
+
+        GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+    }
+
+    return 1;
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
index a475183d367..f3e38c7aa9d 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_fill              (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);
@@ -56,11 +53,15 @@ int ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_diag              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id        (ggml_metal_op_t ctx, int idx);
@@ -68,7 +69,6 @@ int ggml_metal_op_add_id            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_flash_attn_ext    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_bin               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_l2_norm           (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
@@ -84,10 +84,10 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_count_equal       (ggml_metal_op_t ctx, int idx);
 
 #ifdef __cplusplus
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp
index f6f8f7a106a..128c10fa028 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp
@@ -7,11 +7,15 @@
 #include "ggml-metal-context.h"
 #include "ggml-metal-ops.h"
 
-// globals
+#include 
+#include 
 
-// initialized in ggml_backend_metal_reg
-static ggml_backend_reg    g_ggml_metal_reg;
-static ggml_backend_device g_ggml_metal_device;
+#define GGML_METAL_NAME "MTL"
+#define GGML_METAL_MAX_DEVICES 16
+
+// number of Metal devices
+// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices
+static int g_devices = 1;
 
 ////////////////////////////////////////////////////////////////////////////////
 // backend interface
@@ -167,10 +171,28 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
     /* .reset           = */ NULL,
 };
 
+static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer ||
+           buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer;
+}
+
 //
 // buffer types
 //
 
+struct ggml_backend_metal_buffer_type {
+    int device;
+    std::string name;
+};
+
+struct ggml_backend_metal_buffer_type_deleter {
+    void operator()(ggml_backend_metal_buffer_type * ctx) const {
+        delete ctx;
+    }
+};
+
+typedef std::unique_ptr ggml_backend_metal_buffer_type_ptr;
+
 // common method for allocating shread or private Metal buffers
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
@@ -220,9 +242,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
 // default (shared) buffer type
 
 static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -251,29 +273,54 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
-    static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) {
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+
+    static std::vector bufts;
+    static std::vector ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx =
+                new ggml_backend_metal_buffer_type {
+                    /* .device = */ i,
+                    /* .name   = */ GGML_METAL_NAME + std::to_string(i),
+                };
+            ctxs.emplace_back(raw_ctx);
+
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
 
-    return &ggml_backend_buffer_type_metal;
+    return &bufts[device];
 }
 
 // default (private) buffer type
 
 static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal_Private";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -302,29 +349,53 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
-    static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) {
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+
+    static std::vector bufts;
+    static std::vector ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
+                /* .device = */ i,
+                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Private"
+            };
+            ctxs.emplace_back(raw_ctx);
+
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
 
-    return &ggml_backend_buffer_type_metal;
+    return &bufts[device];
 }
 
 // mapped buffer type
 
 static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal_Mapped";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -354,31 +425,55 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
-    // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
-    //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
-    static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) {
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+
+    static std::vector bufts;
+    static std::vector ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
+                /* .device = */ i,
+                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped"
+            };
+            ctxs.emplace_back(raw_ctx);
+
+            // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
+            //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
 
-    return &ggml_backend_buffer_type_mapped_metal;
+    return &bufts[device];
 }
 
 // backend
 
 static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
-    GGML_UNUSED(backend);
+    return ggml_metal_get_name(ctx);
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
@@ -411,20 +506,46 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const gg
 }
 
 static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
-    return false;
+    if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) {
+        return false;
+    }
 
-    GGML_UNUSED(backend_src);
-    GGML_UNUSED(backend_dst);
-    GGML_UNUSED(src);
-    GGML_UNUSED(dst);
+    if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) {
+        return false;
+    }
+
+    ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context;
+    ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context;
+
+    //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+    //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context;
+    //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context;
+
+    return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);
 }
 
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
     ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
+    GGML_UNUSED(batch_size);
+
     return ggml_metal_graph_compute(ctx, cgraph);
+}
 
-    GGML_UNUSED(batch_size);
+static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_event_record(ctx, ev);
+}
+
+static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_event_wait(ctx, ev);
 }
 
 static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
@@ -439,7 +560,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
     ggml_metal_set_n_cb(ctx, n_cb);
-
 }
 
 static ggml_backend_i ggml_backend_metal_i = {
@@ -454,12 +574,8 @@ static ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-
-    // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
-    // in any case, these docs seem relevant if we ever decide to implement it:
-    // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
+    /* .event_record            = */ ggml_backend_metal_event_record,
+    /* .event_wait              = */ ggml_backend_metal_event_wait,
     /* .graph_optimize          = */ ggml_backend_metal_graph_optimize,
 };
 
@@ -523,15 +639,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
 // backend device
 
 static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
-    return "Metal";
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
-    GGML_UNUSED(dev);
+    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
+
+    return props_dev->name;
 }
 
 static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
-    return ggml_metal_device_get_props(ctx_dev)->name;
+    return ggml_metal_device_get_props(ctx_dev)->desc;
 }
 
 static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
@@ -557,14 +675,14 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
 
     props->library = GGML_METAL_NAME;
     props->caps = {
-        /* .async                 = */ true,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ true,
-        /* .events                = */ false,
+        /* .async                = */ true,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ true,
     };
 }
 
-static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
     ggml_metal_t ctx = ggml_metal_init(ctx_dev);
@@ -594,7 +712,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml
 
     const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
 
-    return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private();
+    return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device);
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
@@ -602,7 +720,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen
 
     ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);
 
-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size);
+    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size);
 }
 
 static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
@@ -613,9 +733,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     return
+        buft->device == dev && (
         buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
         buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
-        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name;
+        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name);
 
     GGML_UNUSED(dev);
 }
@@ -632,14 +753,43 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
 }
 
 static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
     return (op->op == GGML_OP_MUL_MAT ||
             op->op == GGML_OP_MUL_MAT_ID) &&
-            get_op_batch_size(op) >= min_batch_size;
+            get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
+}
 
-    GGML_UNUSED(dev);
-    GGML_UNUSED(op);
+static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev);
+    GGML_ASSERT(event);
+
+    ggml_backend_event_t ev = new ggml_backend_event {
+        /* .device  = */ dev,
+        /* .context = */ event,
+    };
+
+    return ev;
+}
+
+static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_device_event_free(ctx_dev, ev);
+
+    delete event;
+}
+
+static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t evt = (ggml_metal_event_t)event->context;
+
+    ggml_metal_device_event_synchronize(ctx_dev, evt);
 }
 
 static ggml_backend_device_i ggml_backend_metal_device_i = {
@@ -648,39 +798,59 @@ static ggml_backend_device_i ggml_backend_metal_device_i = {
     /* .get_memory           = */ ggml_backend_metal_device_get_memory,
     /* .get_type             = */ ggml_backend_metal_device_get_type,
     /* .get_props            = */ ggml_backend_metal_device_get_props,
-    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .init_backend         = */ ggml_backend_metal_device_init_backend,
     /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
     /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
     /* .supports_op          = */ ggml_backend_metal_device_supports_op,
     /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
     /* .offload_op           = */ ggml_backend_metal_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
+    /* .event_new            = */ ggml_backend_metal_device_event_new,
+    /* .event_free           = */ ggml_backend_metal_device_event_free,
+    /* .event_synchronize    = */ ggml_backend_metal_device_event_synchronize,
 };
 
 // backend registry
 
+struct ggml_backend_metal_reg {
+    std::vector devices;
+};
+
+typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t;
+
+static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) {
+    ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg;
+
+    return ctx;
+}
+
+static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) {
+    delete ctx;
+}
+
+struct ggml_backend_metal_reg_deleter {
+    void operator()(ggml_backend_metal_reg_t ctx) {
+        ggml_backend_metal_reg_free(ctx);
+    }
+};
+
+typedef std::unique_ptr ggml_backend_metal_reg_ptr;
+
 static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
-    return "Metal";
+    return GGML_METAL_NAME;
 
     GGML_UNUSED(reg);
 }
 
 static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
+    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
+    return ctx->devices.size();
 }
 
 static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    return &g_ggml_metal_device;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
+    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
 }
 
 static ggml_backend_feature g_ggml_backend_metal_features[] = {
@@ -708,27 +878,67 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const
 
 static ggml_backend_reg_i ggml_backend_metal_reg_i = {
     /* .get_name         = */ ggml_backend_metal_reg_get_name,
-    /* .device_count     = */ ggml_backend_metal_reg_device_count,
-    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_device_count = */ ggml_backend_metal_reg_device_count,
+    /* .get_device       = */ ggml_backend_metal_reg_device_get,
     /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
 };
 
+static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) {
+    return new ggml_backend_device {
+        /* .iface   = */ ggml_backend_metal_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ ggml_metal_device_get(device),
+    };
+}
+
+static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) {
+    delete dev;
+}
+
+struct ggml_backend_device_deleter {
+    void operator()(ggml_backend_dev_t ctx) {
+        ggml_backend_metal_device_free(ctx);
+    }
+};
+
+typedef std::unique_ptr ggml_backend_device_ptr;
+
 ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
     {
-        g_ggml_metal_reg = {
-            /* .api_version = */ GGML_BACKEND_API_VERSION,
-            /* .iface       = */ ggml_backend_metal_reg_i,
-            /* .context     = */ NULL,
-        };
-
-        g_ggml_metal_device = {
-            /* .iface   = */ ggml_backend_metal_device_i,
-            /* .reg     = */ &g_ggml_metal_reg,
-            /* .context = */ ggml_metal_device_get(),
-        };
+        static std::mutex mutex;
+        std::lock_guard lock(mutex);
+
+        const char * env = getenv("GGML_METAL_DEVICES");
+        if (env) {
+            g_devices = atoi(env);
+        }
+
+        static std::vector devs;
+
+        if (!initialized) {
+            static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());
+
+            for (int i = 0; i < g_devices; ++i) {
+                auto * dev = ggml_backend_metal_device_init(®, i);
+                devs.emplace_back(dev);
+
+                reg_ctx->devices.push_back(dev);
+            }
+
+            reg = {
+                /* .api_version = */ GGML_BACKEND_API_VERSION,
+                /* .iface       = */ ggml_backend_metal_reg_i,
+                /* .context     = */ reg_ctx.get(),
+            };
+        }
+
+        initialized = true;
     }
 
-    return &g_ggml_metal_reg;
+    return ®
 }
 
 GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index 4f338aa1356..36b5a6812cd 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
     return x*y;
 }
 
+static inline float sum(float x) {
+    return x;
+}
+
+static inline float sum(float4 x) {
+    return x[0] + x[1] + x[2] + x[3];
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template 
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -895,751 +903,428 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-template 
-kernel void kernel_add_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
-    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+template
+inline T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
 
-    device const float * src1_ptr[F];
-    for (short j = 0; j < F; ++j) {
-        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
-    }
+template T elu_approx(T x);
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
+template<> inline float elu_approx(float x) {
+    return (x > 0.f) ? x : (exp(x) - 1);
+}
 
-        float res = src0_ptr[i0];
+template<> inline float4 elu_approx(float4 x) {
+    float4 res;
 
-#pragma unroll
-        for (short j = 0; j < F; ++j) {
-            res += src1_ptr[j][i10];
-        }
+    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
+    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
+    res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
+    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
 
-        dst_ptr[i0] = res;
-    }
+    return res;
 }
 
-typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
+constant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
 
-template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
-template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
-template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
-template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
-template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
-template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
-template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
-template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
-
-kernel void kernel_sub_fuse_1(
-        constant ggml_metal_kargs_bin & args,
+template 
+kernel void kernel_unary_impl(
+        constant ggml_metal_kargs_unary & args,
         device const char * src0,
-        device const char * src1,
         device       char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+#define FC_OP  FC_unary_op
+#define FC_CNT FC_unary_cnt
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+    device const T0 * src0_ptr;
+    device       T  * dst_ptr;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    int i0;
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
-    }
-}
-
-kernel void kernel_mul_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+    if (FC_CNT) {
+        i0 = tgpig.x;
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        src0_ptr = (device const T0 *) (src0);
+        dst_ptr  = (device       T  *) (dst);
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int k0  = tgpig.x/args.ne01;
+        const int i01 = tgpig.x - k0*args.ne01;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        i0 = k0*ntg.x + tpitg.x;
 
-    if (args.ne10 == 1) {
-        const float x = *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
-        }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
-        }
+        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
     }
-}
 
-kernel void kernel_div_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+    {
+        //threadgroup_barrier(mem_flags::mem_none);
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+        if (!FC_CNT) {
+            if (i0 >= args.ne0) {
+                return;
+            }
+        }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        const TC x = (TC) src0_ptr[i0];
 
-    if (args.ne10 == 1) {
-        const float x = 1.0f / *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        if (FC_OP == OP_UNARY_NUM_SCALE) {
+            dst_ptr[i0] = (T) (args.scale * x + args.bias);
         }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+
+        if (FC_OP == OP_UNARY_NUM_FILL) {
+            dst_ptr[i0] = (T) args.val;
         }
-    }
-}
 
-kernel void kernel_add_id(
-        constant ggml_metal_kargs_add_id & args,
-        device const char * src0,
-        device const char * src1,
-        device const char * src2,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i1 = tgpig.x;
-    const int i2 = tgpig.y;
+        if (FC_OP == OP_UNARY_NUM_CLAMP) {
+            dst_ptr[i0] = (T) clamp(x, args.min, args.max);
+        }
 
-    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+        if (FC_OP == OP_UNARY_NUM_SQR) {
+            dst_ptr[i0] = (T) (x * x);
+        }
 
-    const size_t nb1 = args.ne0 * sizeof(float);
-    const size_t nb2 = args.ne1 * nb1;
+        if (FC_OP == OP_UNARY_NUM_SQRT) {
+            dst_ptr[i0] = (T) sqrt(x);
+        }
 
-    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
-    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
-    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+        if (FC_OP == OP_UNARY_NUM_SIN) {
+            dst_ptr[i0] = (T) sin(x);
+        }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        dst_row[i0] = src0_row[i0] + src1_row[i0];
-    }
-}
+        if (FC_OP == OP_UNARY_NUM_COS) {
+            dst_ptr[i0] = (T) cos(x);
+        }
 
-template
-kernel void kernel_repeat(
-        constant ggml_metal_kargs_repeat & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i3 = tgpig.z;
-    const int i2 = tgpig.y;
-    const int i1 = tgpig.x;
+        if (FC_OP == OP_UNARY_NUM_LOG) {
+            dst_ptr[i0] = (T) log(x);
+        }
 
-    const int i03 = i3%args.ne03;
-    const int i02 = i2%args.ne02;
-    const int i01 = i1%args.ne01;
+        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
+            dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
+        }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
+        if (FC_OP == OP_UNARY_NUM_TANH) {
+            dst_ptr[i0] = (T) precise::tanh(x);
+        }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i00 = i0%args.ne00;
-        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
-    }
-}
+        if (FC_OP == OP_UNARY_NUM_RELU) {
+            dst_ptr[i0] = (T) fmax(0, x);
+        }
 
-typedef decltype(kernel_repeat) kernel_repeat_t;
+        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
+            dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
+        }
 
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat;
+        if (FC_OP == OP_UNARY_NUM_GELU) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
+        }
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-template 
-kernel void kernel_add_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
+        }
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
+            dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
+        }
 
-    float4 res = src0_row[tpig];
+        if (FC_OP == OP_UNARY_NUM_SILU) {
+            dst_ptr[i0] = (T) (x / (1 + exp(-x)));
+        }
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res += ((device const float4 *) (src1 + args.o1[j]))[i];
-    }
+        if (FC_OP == OP_UNARY_NUM_ELU) {
+            dst_ptr[i0] = (T) elu_approx(x);
+        }
 
-    dst_row[tpig] = res;
-}
+        if (FC_OP == OP_UNARY_NUM_NEG) {
+            dst_ptr[i0] = (T) -x;
+        }
 
-typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+        if (FC_OP == OP_UNARY_NUM_ABS) {
+            dst_ptr[i0] = (T) fabs(x);
+        }
 
-template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
-template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
-template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
-template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
-template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
-template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
-template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
-template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+        if (FC_OP == OP_UNARY_NUM_SGN) {
+            dst_ptr[i0] = T(x > 0) - T(x < 0);
+        }
 
-template 
-kernel void kernel_sub_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
+        if (FC_OP == OP_UNARY_NUM_STEP) {
+            dst_ptr[i0] = T(x > 0);
+        }
 
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
+        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
+            dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
+        }
 
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
+            dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
+        }
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
+        if (FC_OP == OP_UNARY_NUM_EXP) {
+            dst_ptr[i0] = (T) exp(x);
+        }
 
-    float4 res = src0_row[tpig];
+        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
+            dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
+        }
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res -= src1_row[j][i];
+        if (FC_OP == OP_UNARY_NUM_EXPM1) {
+            // TODO: precise implementation
+            dst_ptr[i0] = (T) (exp(x) - 1);
+        }
     }
 
-    dst_row[tpig] = res;
+#undef FC_OP
+#undef FC_CNT
 }
 
-typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
-
-template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
-
-template 
-kernel void kernel_mul_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
+typedef decltype(kernel_unary_impl) kernel_unary_t;
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
+template [[host_name("kernel_unary_f32_f32")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl;
 
-    float4 res = src0_row[tpig];
+// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
+constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
+constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
+constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
 
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res *= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
-
-template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
-
-template 
-kernel void kernel_div_row_c4_fuse_impl(
+template 
+kernel void kernel_bin_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res /= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-
-template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
-
-kernel void kernel_scale_f32(
-        constant ggml_metal_kargs_scale & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_scale_f32_4(
-        constant ggml_metal_kargs_scale & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_fill_f32(
-        constant ggml_metal_kargs_fill & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_fill_f32_4(
-        constant ggml_metal_kargs_fill & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_clamp_f32(
-        constant ggml_metal_kargs_clamp & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_clamp_f32_4(
-        constant ggml_metal_kargs_clamp & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_relu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_relu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_sigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-kernel void kernel_tanh_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
-
-kernel void kernel_gelu_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+#define FC_OP FC_bin_op
+#define FC_F  FC_bin_f
+#define FC_RB FC_bin_rb
 
-kernel void kernel_gelu_quick_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
+    if (FC_RB) {
+        // row broadcast
+        const uint i0 = tgpig.x;
+        const uint i1 = i0%args.ne10;
 
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
+        device const T0 * src0_row = (device const T0 *) (src0);
+        device       T  * dst_row  = (device       T  *) (dst);
 
-kernel void kernel_gelu_quick_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
+        if (FC_F == 1) {
+            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
 
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
+            if (FC_OP == 0) {
+                dst_row[i0] = src0_row[i0] + src1_row[i1];
+            }
 
-// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
-// ref: https://www.johndcook.com/blog/python_erf/
-constant float p_erf  = 0.3275911f;
-constant float a1_erf = 0.254829592f;
-constant float a2_erf = -0.284496736f;
-constant float a3_erf = 1.421413741f;
-constant float a4_erf = -1.453152027f;
-constant float a5_erf = 1.061405429f;
+            if (FC_OP == 1) {
+                dst_row[i0] = src0_row[i0] - src1_row[i1];
+            }
 
-template
-T erf_approx(T x) {
-    T sign_x = sign(x);
-    x = fabs(x);
-    T t = 1.0f / (1.0f + p_erf * x);
-    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
-    return sign_x * y;
-}
+            if (FC_OP == 2) {
+                dst_row[i0] = src0_row[i0] * src1_row[i1];
+            }
 
-kernel void kernel_gelu_erf_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
+            if (FC_OP == 3) {
+                dst_row[i0] = src0_row[i0] / src1_row[i1];
+            }
+        } else {
+            T0 res = src0_row[i0];
 
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
+            if (FC_OP == 0) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-kernel void kernel_gelu_erf_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
+            if (FC_OP == 1) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
+            if (FC_OP == 2) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-kernel void kernel_silu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
+            if (FC_OP == 3) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
 
-kernel void kernel_silu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
+            dst_row[i0] = res;
+        }
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int i01 = tgpig.x;
 
-kernel void kernel_elu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
-}
+        if (i01 >= args.ne01) {
+            return;
+        }
 
-kernel void kernel_elu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
-    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
-    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
-    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
-}
+        const int i13 = i03%args.ne13;
+        const int i12 = i02%args.ne12;
+        const int i11 = i01%args.ne11;
 
-kernel void kernel_sqr_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
+        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
 
-kernel void kernel_sqr_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
+        if (FC_F == 1) {
+            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
 
-kernel void kernel_sqrt_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-kernel void kernel_sqrt_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
+                if (FC_OP == 0) {
+                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
+                }
 
-kernel void kernel_sin_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
+                if (FC_OP == 1) {
+                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
+                }
 
-kernel void kernel_sin_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
+                if (FC_OP == 2) {
+                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
+                }
 
-kernel void kernel_cos_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
+                if (FC_OP == 3) {
+                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
+                }
+            }
+        } else {
+            device const T1 * src1_ptr[8];
+            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+            }
 
-kernel void kernel_cos_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = i0%args.ne10;
 
-kernel void kernel_log_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
+                T res = src0_ptr[i0];
 
-kernel void kernel_log_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
+                if (FC_OP == 0) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res += src1_ptr[j][i10];
+                    }
+                }
 
-kernel void kernel_neg_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
+                if (FC_OP == 1) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res -= src1_ptr[j][i10];
+                    }
+                }
 
-kernel void kernel_neg_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
+                if (FC_OP == 2) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res *= src1_ptr[j][i10];
+                    }
+                }
 
-kernel void kernel_abs_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
+                if (FC_OP == 3) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res /= src1_ptr[j][i10];
+                    }
+                }
 
-kernel void kernel_abs_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
+                dst_ptr[i0] = res;
+            }
+        }
+    }
 
-kernel void kernel_sgn_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
+#undef FC_OP
+#undef FC_F
+#undef FC_RB
 }
 
-kernel void kernel_sgn_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
+typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t;
 
-kernel void kernel_step_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
+template [[host_name("kernel_bin_fuse_f32_f32_f32")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
+template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
 
-kernel void kernel_step_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
 
-kernel void kernel_hardswish_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
 
-kernel void kernel_hardswish_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
 
-kernel void kernel_hardsigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
+    device       float * dst_row  = (device       float *)((device char *)dst  +  i1*nb1       + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
 
-kernel void kernel_hardsigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
 }
 
-kernel void kernel_exp_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
+template
+kernel void kernel_repeat(
+        constant ggml_metal_kargs_repeat & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
 
-kernel void kernel_exp_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
+    const int i03 = i3%args.ne03;
+    const int i02 = i2%args.ne02;
+    const int i01 = i1%args.ne01;
 
-kernel void kernel_softplus_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
 
-kernel void kernel_softplus_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i00 = i0%args.ne00;
+        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+    }
 }
 
-kernel void kernel_expm1_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
+typedef decltype(kernel_repeat) kernel_repeat_t;
 
-kernel void kernel_expm1_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat;
 
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
@@ -1790,6 +1475,7 @@ kernel void kernel_op_sum_f32(
         return;
     }
 
+    // TODO: turn this into a function constant
     const uint nsg = (ntg.x + 31) / 32;
 
     float sumf = 0;
@@ -1823,33 +1509,35 @@ kernel void kernel_op_sum_f32(
     }
 }
 
-template 
-kernel void kernel_sum_rows(
+constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
+
+template 
+kernel void kernel_sum_rows_impl(
         constant ggml_metal_kargs_sum_rows & args,
-        device const float * src0,
-        device       float * dst,
-        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        device const char * src0,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    int64_t i3 = tgpig.z;
-    int64_t i2 = tgpig.y;
-    int64_t i1 = tgpig.x;
+#define FC_OP  FC_sum_rows_op
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
-        return;
-    }
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
 
     if (sgitg == 0) {
-        shmem_f32[tiisg] = 0.0f;
+        shmem_t[tiisg] = 0.0f;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float sumf = 0;
+    T0 sumf = T0(0.0f);
 
     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
         sumf += src_row[i0];
@@ -1860,23 +1548,33 @@ kernel void kernel_sum_rows(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     if (tiisg == 0) {
-        shmem_f32[sgitg] = sumf;
+        shmem_t[sgitg] = sumf;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sumf = shmem_f32[tiisg];
+    sumf = shmem_t[tiisg];
     sumf = simd_sum(sumf);
 
     if (tpitg.x == 0) {
-        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+            if (is_same::value) {
+                dst_row[0] = sum(sumf) / (4*args.ne00);
+            } else {
+                dst_row[0] = sum(sumf) / args.ne00;
+            }
+        } else {
+            dst_row[0] = sum(sumf);
+        }
     }
+
+#undef FC_OP
 }
 
-typedef decltype(kernel_sum_rows) kernel_sum_rows_t;
+typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t;
 
-template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows;
-template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows;
+template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl;
+template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl;
 
 template
 kernel void kernel_cumsum_blk(
@@ -2736,6 +2434,80 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
+constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
+constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
+
+kernel void kernel_solve_tri_f32(
+        constant ggml_metal_kargs_solve_tri & args,
+        device   const char * src0,
+        device   const char * src1,
+        device         char * dst,
+        threadgroup    char * shmem [[threadgroup(0)]],
+        ushort3 tgpig[[threadgroup_position_in_grid]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const short NSG = FC_solve_tri_nsg;
+    const short N   = FC_solve_tri_n;
+    const short K   = FC_solve_tri_k;
+    const short NP  = PAD2(N, NW);
+
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x*NSG + sgitg;
+
+    threadgroup float * sh0 = (threadgroup float *) shmem;
+
+    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
+    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
+    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;
+
+    for (short rr = 0; rr < N; rr += NSG) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        {
+            threadgroup float * sh0_cur = sh0 + sgitg*NP;
+
+            for (short t = 0; t*NW < N; ++t) {
+                const short idx = t*NW + tiisg;
+                sh0_cur[idx] = src0_ptr[idx];
+            }
+
+            src0_ptr += NSG*N;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (i01 >= args.ne10) {
+            continue;
+        }
+
+        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
+            const short r = rr + ir;
+
+            threadgroup float * sh0_cur = sh0 + ir*NP;
+
+            float sum = 0.0f;
+
+            for (short t = 0; t*NW < r; ++t) {
+                const short idx = t*NW + tiisg;
+                sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
+            }
+
+            sum = simd_sum(sum);
+
+            if (tiisg == 0) {
+                const float diag = sh0_cur[r];
+
+                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
+            }
+        }
+    }
+}
+
 kernel void kernel_argmax_f32(
         constant ggml_metal_kargs_argmax & args,
         device   const char * src0,
@@ -2969,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_f
 template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 
-kernel void kernel_l2_norm_f32(
+template 
+kernel void kernel_l2_norm_impl(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -3006,71 +2784,15 @@ kernel void kernel_l2_norm_f32(
 
     const float scale = 1.0f/sqrt(max(sumf, args.eps));
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;
     }
 }
 
-kernel void kernel_solve_tri_f32(
-        constant ggml_metal_kargs_solve_tri & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
-    const uint64_t ncols = (uint64_t) args.ne10;
-    const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
-    const uint64_t nr = n_batches * ncols;
-
-    const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
-    if (gid >= nr) {
-        return;
-    }
-
-    const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
-    const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
-    const uint64_t i02 = rem / ncols;
-    const uint64_t i01 = rem - i02 * ncols;
-
-    const uint64_t sa0 = args.nb00 / sizeof(float);
-    const uint64_t sa1 = args.nb01 / sizeof(float);
-    const uint64_t sa2 = args.nb02 / sizeof(float);
-    const uint64_t sa3 = args.nb03 / sizeof(float);
-
-    const uint64_t sb0 = args.nb10 / sizeof(float);
-    const uint64_t sb1 = args.nb11 / sizeof(float);
-    const uint64_t sb2 = args.nb12 / sizeof(float);
-    const uint64_t sb3 = args.nb13 / sizeof(float);
-
-    const uint64_t sx0 = args.nb0 / sizeof(float);
-    const uint64_t sx1 = args.nb1 / sizeof(float);
-    const uint64_t sx2 = args.nb2 / sizeof(float);
-    const uint64_t sx3 = args.nb3 / sizeof(float);
-
-    device const float * A = (device const float *) src0;
-    device const float * B = (device const float *) src1;
-    device       float * X = (device       float *) dst;
+typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t;
 
-    const uint64_t A_base = i02 * sa2 + i03 * sa3;
-    const uint64_t B_base = i02 * sb2 + i03 * sb3;
-    const uint64_t X_base = i02 * sx2 + i03 * sx3;
-
-    const uint64_t n = (uint64_t) args.ne11;
-
-    for (uint64_t i00 = 0; i00 < n; ++i00) {
-        float sum = 0.0f;
-        for (uint64_t t = 0; t < i00; ++t) {
-            sum += A[A_base + i00 * sa1 + t * sa0] *
-                X[X_base + t * sx1 + i01 * sx0];
-        }
-
-        const float diag = A[A_base + i00 * sa1 + i00 * sa0];
-        X[X_base + i00 * sx1 + i01 * sx0] =
-            (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
-    }
-}
+template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl;
 
 kernel void kernel_group_norm_f32(
         constant ggml_metal_kargs_group_norm & args,
@@ -5388,24 +5110,6 @@ template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge
 template [[host_name("kernel_argsort_merge_i32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_i32_i32;
 template [[host_name("kernel_argsort_merge_i32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_i32_i32;
 
-kernel void kernel_leaky_relu_f32(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x > 0.0f ? x : x * args.slope;
-}
-
-kernel void kernel_leaky_relu_f32_4(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
-}
-
 constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
 
 constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -5482,6 +5186,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
 // scan the blocks of the mask that are not masked
 // 0 -     masked (i.e. full of -INF, skip)
 // 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - all zero
 kernel void kernel_flash_attn_ext_blk(
         constant ggml_metal_kargs_flash_attn_ext_blk & args,
         device const char * mask,
@@ -5503,27 +5208,29 @@ kernel void kernel_flash_attn_ext_blk(
 
     device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
 
-    // fast route
-    if (res == 0) {
-        if (simd_max(*mask_src) > -MAXHALF/2) {
-            res = 1;
-        }
-    }
-
     // detailed check of the elements of the block
     if ((C > NW || Q > 1) && res == 0) {
-        half m = -MAXHALF;
+        half mmin =  MAXHALF;
+        half mmax = -MAXHALF;
 
         FOR_UNROLL (short j = 0; j < Q; ++j) {
             FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
-                m = max(m, mask_src[ii*NW]);
+                mmin = min(mmin, mask_src[ii*NW]);
+                mmax = max(mmax, mask_src[ii*NW]);
             }
 
             mask_src += args.nb31/2;
         }
 
-        if (simd_max(m) > -MAXHALF/2) {
-            res = 1;
+        mmin = simd_min(mmin);
+        mmax = simd_max(mmax);
+
+        if (mmax > -MAXHALF) {
+            if (mmin == 0.0 && mmax == 0.0) {
+                res = 2;
+            } else {
+                res = 1;
+            }
         }
     }
 
@@ -5765,9 +5472,13 @@ void kernel_flash_attn_ext_impl(
                 ic = 0;
             }
 
+            char blk_cur = 1;
+
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
-                if (blk[ic0] == 0) {
+                blk_cur = blk[ic0];
+
+                if (blk_cur == 0) {
                     FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                         pm2[jj] += NW;
                     }
@@ -5775,16 +5486,22 @@ void kernel_flash_attn_ext_impl(
                     continue;
                 }
 
-                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
-                    const short j = jj*NSG + sgitg;
+                if (blk_cur == 1) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
+
+                        if (FC_flash_attn_ext_bc_mask) {
+                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                        } else {
+                            sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        }
 
-                    if (FC_flash_attn_ext_bc_mask) {
-                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
-                    } else {
-                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        pm2[jj] += NW;
+                    }
+                } else if (blk_cur == 2) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        pm2[jj] += NW;
                     }
-
-                    pm2[jj] += NW;
                 }
 
 #if 0
@@ -5826,9 +5543,7 @@ void kernel_flash_attn_ext_impl(
 
                 constexpr short NC = (C/8)/NSG;
 
-                // note: do not unroll for large heads
-                #pragma unroll (DK <= 64 ? NC : 1)
-                for (short cc = 0; cc < NC; ++cc) {
+                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f);
 
                     if (DK % 16 != 0) {
@@ -5849,7 +5564,9 @@ void kernel_flash_attn_ext_impl(
                         k8x8_t mk[2];
                         q8x8_t mq[2];
 
-                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+                        // note: too much unroll can tank the performance for large heads
+                        #pragma unroll (MIN(DK8/2, 4*NSG))
+                        for (short i = 0; i < DK8/2; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5949,10 +5666,12 @@ void kernel_flash_attn_ext_impl(
                 }
 
                 // mqk = mqk + slope*mask
-                if (FC_flash_attn_ext_has_bias) {
-                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
-                } else {
-                    s2 += s2_t(sm2[j*SH + tiisg]);
+                if (blk_cur != 2) {
+                    if (FC_flash_attn_ext_has_bias) {
+                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                    } else {
+                        s2 += s2_t(sm2[j*SH + tiisg]);
+                    }
                 }
 
                 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
@@ -6023,7 +5742,9 @@ void kernel_flash_attn_ext_impl(
                                 pv  += 8*NS20;
                             }
                         } else {
-                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+                            constexpr short NC = (C/8)/2;
+
+                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                                 s8x8_t vs[2];
 
                                 simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -6203,7 +5924,7 @@ template<
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
     short DK,         // K head size
     short DV,         // V head size
-    short Q  = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
     short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         constant ggml_metal_kargs_flash_attn_ext & args,
@@ -6413,11 +6134,10 @@ template<
     void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
     short DK,       // K head size
     short DV,       // V head size
-    short NE,       // head elements per thread
-    short Q,        // queries per threadgroup
-    short C,        // cache items per threadgroup
-    short NSG>      // number of simd groups
-void kernel_flash_attn_ext_vec_impl(
+    short NE = 4,   // head elements per thread
+    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
         device const char * k,
@@ -6434,6 +6154,7 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DV % 32 == 0, "DV must be divisible by 32");
 
 #define NWG  (FC_flash_attn_ext_vec_nwg)
+#define NSG  (FC_flash_attn_ext_vec_nsg)
 
 #define NS10 (FC_flash_attn_ext_vec_ns10)
 #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6460,14 +6181,14 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
     static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
 
-    const short T = PK + NSG*SH; // shared memory size per query in (half)
+  //const short T = PK + NSG*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*PK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*PK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*PK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*PK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
-    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                      0*PK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t
+    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
+    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results
 
     // store the result for all queries in shared memory (the O matrix from the paper)
     so4 += tiisg;
@@ -6485,11 +6206,13 @@ void kernel_flash_attn_ext_vec_impl(
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q);
 
-    for (short i = tiisg; i < PK4; i += NW) {
-        if (iq1 < args.ne01 && i < DK4) {
-            sq4[i] = (q4_t) q4[i];
-        } else {
-            sq4[i] = (q4_t) 0.0f;
+    if (iq1 < args.ne01) {
+        for (short i = tiisg; i < PK4; i += NW) {
+            if (i < DK4) {
+                sq4[i] = (q4_t) q4[i];
+            } else {
+                sq4[i] = (q4_t) 0.0f;
+            }
         }
     }
 
@@ -6567,7 +6290,7 @@ void kernel_flash_attn_ext_vec_impl(
             }
 
             // skip -INF blocks
-            if (simd_max(sm[tiisg]) == -INFINITY) {
+            if (simd_max(sm[tiisg]) <= -MAXHALF) {
                 continue;
             }
 
@@ -6841,57 +6564,11 @@ void kernel_flash_attn_ext_vec_impl(
     }
 
 #undef NWG
+#undef NSG
 #undef NS10
 #undef NS20
 }
 
-template<
-    typename q4_t,  // query types in shared memory
-    typename k4_t,  // key types in shared memory
-    typename v4_t,  // value types in shared memory
-    typename qk_t,  // Q*K types
-    typename s_t,   // soft-max types
-    typename s4_t,
-    typename o4_t,  // attention accumulation types
-    typename kd4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
-    typename vd4_t, // value type in device memory
-    short nl_v,
-    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
-    short DK,       // K head size
-    short DV,       // V head size
-    short NE = 4,   // head elements per thread
-    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG,  // queries per threadgroup
-    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        constant ggml_metal_kargs_flash_attn_ext_vec & args,
-        device const char * q,
-        device const char * k,
-        device const char * v,
-        device const char * mask,
-        device const char * sinks,
-        device const char * pad,
-        device       char * dst,
-        threadgroup  half * shmem_f16 [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
-    switch (FC_flash_attn_ext_vec_nsg) {
-      // note: disabled cases to reduce library load time
-        case 1:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 2:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 4:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 8:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-    }
-#undef FWD_TMPL
-#undef FWD_ARGS
-}
-
 // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
@@ -9054,6 +8731,26 @@ kernel void kernel_set_rows_f(
     }
 }
 
+kernel void kernel_diag_f32(
+        constant ggml_metal_kargs_diag & args,
+        device   const char * src0,
+        device         char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig.z;
+    const int32_t i2 = tgpig.y;
+    const int32_t i1 = tgpig.x;
+
+    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);
+    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2  + i3*args.nb3);
+
+    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
+
 constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
 constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
 
@@ -9072,7 +8769,9 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -9195,8 +8894,8 @@ kernel void kernel_mul_mm(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -9423,6 +9122,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
 template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
@@ -9445,7 +9145,9 @@ kernel void kernel_mul_mm_id(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -9580,8 +9282,8 @@ kernel void kernel_mul_mm_id(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -9834,9 +9536,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_m
 
 template [[host_name("kernel_mul_mm_f32_f16")]]     kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_f16_f16")]]     kernel mul_mm_t kernel_mul_mm;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mm_bf16_f16")]]    kernel mul_mm_t kernel_mul_mm;
-#endif
 template [[host_name("kernel_mul_mm_q4_0_f16")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q4_1_f16")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q5_0_f16")]]    kernel mul_mm_t kernel_mul_mm;
@@ -9892,9 +9591,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mul_mm_id kernel_m
 
 template [[host_name("kernel_mul_mm_id_f32_f16")]]     kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_f16_f16")]]     kernel mul_mm_id kernel_mul_mm_id;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
-#endif
 template [[host_name("kernel_mul_mm_id_q4_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
@@ -10150,6 +9846,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+
+kernel void kernel_pool_1d_max_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW){
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
 kernel void kernel_opt_step_adamw_f32(
         constant    ggml_metal_kargs_opt_step_adamw & args,
         device       float * x,
@@ -10197,3 +9961,75 @@ kernel void kernel_opt_step_sgd_f32(
 
     x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
 }
+
+template
+kernel void kernel_memset(
+        constant ggml_metal_kargs_memset & args,
+        device T * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = args.val;
+}
+
+typedef decltype(kernel_memset) kernel_memset_t;
+
+template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset;
+
+constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
+
+template
+kernel void kernel_count_equal(
+        constant ggml_metal_kargs_count_equal & args,
+        device   const char * src0,
+        device   const char * src1,
+        device   atomic_int * dst,
+        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const short NSG = FC_count_equal_nsg;
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
+        return;
+    }
+
+    int sum = 0;
+
+    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
+    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
+
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        const T v0 = *(device const T *)(base0 + i0*args.nb00);
+        const T v1 = *(device const T *)(base1 + i0*args.nb10);
+        sum += (v0 == v1);
+    }
+
+    sum = simd_sum(sum);
+
+    if (tiisg == 0) {
+        shmem_i32[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (sgitg == 0) {
+        float v = 0.0f;
+        if (tpitg.x < NSG) {
+            v = shmem_i32[tpitg.x];
+        }
+
+        float total = simd_sum(v);
+        if (tpitg.x == 0) {
+            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
+        }
+    }
+}
+
+typedef decltype(kernel_count_equal) kernel_count_equal_t;
+
+template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 9cc4ebdeffb..3f7abaf45fe 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -93,6 +93,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 #define VK_VENDOR_ID_APPLE 0x106b
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
+#define VK_VENDOR_ID_QUALCOMM 0x5143
 
 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
 
@@ -120,6 +121,8 @@ struct ggml_backend_vk_context;
 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
 #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
 
+typedef std::shared_ptr vk_pipeline;
+
 struct vk_pipeline_struct {
     std::string name;
     vk::ShaderModule shader_module;
@@ -137,9 +140,15 @@ struct vk_pipeline_struct {
     std::atomic compiled {};
     // number of registers used, extracted from pipeline executable properties
     uint32_t register_count {};
+
+#if defined(VK_EXT_shader_64bit_indexing)
+    bool is_64b_indexing {};
+#endif
+    // linked list of pipelines for multiple compilation variants.
+    // currently only used to compile a 64-bit indexing variant.
+    vk_pipeline next;
 };
 
-typedef std::shared_ptr vk_pipeline;
 typedef std::weak_ptr vk_pipeline_ref;
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
@@ -231,9 +240,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
     /* .is_host          = */ NULL,
 };
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
 class vk_memory_logger;
-#endif
 class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
@@ -250,6 +257,7 @@ enum vk_device_architecture {
     AMD_RDNA3,
     INTEL_XE2,
     NVIDIA_PRE_TURING,
+    NVIDIA_TURING,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -332,18 +340,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
         const std::vector ext_props = device.enumerateDeviceExtensionProperties();
 
         bool cooperative_matrix = false;
+        bool sm_builtins = false;
 
         // Detect "pre-turing" based on lack of coopmat support.
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
                 cooperative_matrix = true;
-                break;
+            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+                sm_builtins = true;
             }
         }
 
         if (!cooperative_matrix) {
             return vk_device_architecture::NVIDIA_PRE_TURING;
         }
+
+        if (sm_builtins) {
+            vk::PhysicalDeviceProperties2 props2;
+            vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
+
+            props2.pNext = &sm_props;
+
+            device.getProperties2(&props2);
+
+            // Turing has 32, following architectures have 48
+            if (sm_props.shaderWarpsPerSM == 32) {
+                return vk_device_architecture::NVIDIA_TURING;
+            }
+        }
     }
     return vk_device_architecture::OTHER;
 }
@@ -381,18 +405,20 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {}
-
     uint32_t HSK, HSV;
-    bool small_rows;
+    uint32_t Br, Bc;
+    uint32_t D_split, row_split;
+    bool shmem_staging;
     FaCodePath path;
+    uint32_t workgroup_size, subgroup_size;
     bool aligned;
     bool f32acc;
+    uint32_t flags;
+    uint32_t limit_occupancy_shmem;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
+               std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
     }
 };
 
@@ -436,8 +462,15 @@ static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGM
                                                                              GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
                                                                              GGML_OP_RESHAPE };
+
+static constexpr std::initializer_list topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY,    GGML_OP_RESHAPE,  GGML_OP_ADD,
+                                                                            GGML_OP_ARGSORT,  GGML_OP_VIEW,     GGML_OP_GET_ROWS,
+                                                                            GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                                            GGML_OP_DIV,      GGML_OP_RESHAPE };
+
 static constexpr std::initializer_list topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+
 static constexpr std::initializer_list topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,
                                                                              GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -466,6 +499,32 @@ static constexpr std::initializer_list> topk_moe_early_softma
     { 9, 0, 8 }, // reshape->src[0]  == div
 };
 
+//node #436 (     UNARY):     ffn_moe_probs-10 ( 256K) [Vulka         ] use=2:    ffn_moe_logits-10 ( 256K) [Vulka         ]
+//node #437 (   RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ]
+//node #438 (       ADD): ffn_moe_probs_biased ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ] blk.10.exp_probs_b.b (   0K) [Vulka         ]
+//node #439 (   ARGSORT):   ffn_moe_argsort-10 ( 256K) [Vulka         ] use=1: ffn_moe_probs_biased ( 256K) [Vulka         ]
+//node #440 (      VIEW):      ffn_moe_topk-10 ( 255K) [Vulka         ] use=3:   ffn_moe_argsort-10 ( 256K) [Vulka         ]
+//node #441 (  GET_ROWS):   ffn_moe_weights-10 (  12K) [Vulka         ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka         ]      ffn_moe_topk-10 ( 255K) [Vulka         ]
+//node #442 (   RESHAPE): ffn_moe_weights-10 ( (  12K) [Vulka         ] use=2:   ffn_moe_weights-10 (  12K) [Vulka         ]
+//node #443 (  SUM_ROWS): ffn_moe_weights_sum- (   2K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ]
+//node #444 (     CLAMP): ffn_moe_weights_sum_ (   2K) [Vulka         ] use=1: ffn_moe_weights_sum- (   2K) [Vulka         ]
+//node #445 (       DIV): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ] ffn_moe_weights_sum_ (   2K) [Vulka         ]
+//node #446 (   RESHAPE): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights_norm (  12K) [Vulka         ]
+static constexpr std::initializer_list> topk_moe_sigmoid_norm_bias_edges {
+    { 1, 0, 0 }, // reshape->src[0]  == sigmoid
+    { 2, 0, 0 }, // add->src[0]      == sigmoid
+    { 3, 0, 2 }, // argsort->src[0]  == add
+    { 4, 0, 3 }, // view->src[0]     == argsort
+    { 5, 0, 1 }, // get_rows->src[0] == reshape
+    { 5, 1, 4 }, // get_rows->src[1] == view
+    { 6, 0, 5 }, // reshape->src[0]  == get_rows
+    { 7, 0, 6 }, // sum_rows->src[0] == reshape
+    { 8, 0, 7 }, // clamp->src[0]    == sum_rows
+    { 9, 0, 6 }, // div->src[0]      == reshape
+    { 9, 1, 8 }, // div->src[1]      == clamp
+    {10, 0, 9 }, // reshape->src[0]  == div
+};
+
 // same as early_softmax_norm but ending after the get_rows
 static constexpr std::initializer_list> topk_moe_early_softmax_edges {
     { 1, 0, 0 }, // reshape->src[0]  == softmax
@@ -493,16 +552,10 @@ enum topk_moe_mode {
     TOPK_MOE_EARLY_SOFTMAX,
     TOPK_MOE_EARLY_SOFTMAX_NORM,
     TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_SIGMOID_NORM_BIAS,
     TOPK_MOE_COUNT,
 };
 
-static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
-    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
-                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
-                                                                         TOPK_MOE_LATE_SOFTMAX;
-    return mode;
-}
-
 static constexpr std::initializer_list> rope_view_set_rows_edges {
     { 1, 0, 0 }, // view->src[0]     == rope
     { 2, 0, 1 }, // set_rows->src[0] == view
@@ -525,6 +578,8 @@ struct vk_device_struct {
     uint64_t max_memory_allocation_size;
     uint64_t max_buffer_size;
     uint64_t suballocation_block_size;
+    uint64_t min_imported_host_pointer_alignment;
+    bool external_memory_host {};
     bool fp16;
     bool bf16;
     bool pipeline_robustness;
@@ -537,12 +592,14 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     bool support_async;
+    bool async_use_transfer_queue;
     uint32_t subgroup_size;
     uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_basic;
     bool subgroup_arithmetic;
     bool subgroup_shuffle;
     bool subgroup_ballot;
@@ -556,6 +613,8 @@ struct vk_device_struct {
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
 
+    bool shader_64b_indexing;
+
     bool integer_dot_product;
     // 0: default, 1: force mmvq, -1: disable mmvq
     int32_t mmvq_mode;
@@ -633,6 +692,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
+    vk_pipeline pipeline_set_f32;
 
     // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
     vk_pipeline pipeline_add[2][2][2];
@@ -653,7 +713,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
-    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32;
+    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
     vk_pipeline pipeline_sqrt_f32;
@@ -691,6 +751,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_xielu[2];
     vk_pipeline pipeline_neg[2];
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
@@ -732,13 +793,16 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
-    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
+    vk_pipeline pipeline_cumsum_small_f32;
+    vk_pipeline pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline_cumsum_multipass2_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     std::map pipeline_solve_tri_f32;
@@ -763,10 +827,13 @@ struct vk_device_struct {
 
     std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
+    std::map, vk_pipeline> pipeline_fa_mask_opt;
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
+    vk_pipeline pipeline_count_experts;
 
     // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
 
     std::vector all_pipelines;
 
@@ -782,9 +849,7 @@ struct vk_device_struct {
     bool allow_sysmem_fallback;
     bool disable_graph_optimize;
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr memory_logger;
-#endif
 
     ~vk_device_struct() {
         VK_LOG_DEBUG("destroy device " << name);
@@ -857,6 +922,15 @@ struct vk_subbuffer {
     }
 };
 
+// vk_event is used for the event-related backend interfaces. It uses 'event' for
+// event_wait and 'fence' for event_synchronize. Polling on an event for
+// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
+// and would lead to validation errors.
+struct vk_event {
+    vk::Event event;
+    vk::Fence fence;
+};
+
 struct vk_semaphore {
     vk::Semaphore s;
     uint64_t value;
@@ -874,6 +948,7 @@ struct vk_mat_mat_push_constants {
     uint32_t M; uint32_t N; uint32_t K;
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+    uint32_t base_work_group_z; uint32_t num_batches;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
     uint32_t padded_N;
@@ -893,6 +968,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t fusion_flags;
+    uint32_t base_work_group_y;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -943,6 +1019,8 @@ struct vk_mat_vec_id_push_constants {
     uint32_t fusion_flags;
     uint32_t nei0;
     uint32_t ne11;
+    uint32_t expert_i1;
+    uint32_t nbi1;
 };
 
 struct vk_flash_attn_push_constants {
@@ -992,6 +1070,16 @@ struct vk_op_push_constants {
     uint32_t KY;
     float param1;
     float param2;
+    float param3;
+    float param4;
+};
+
+struct vk_op_count_experts_push_constants {
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t nb00;
+    uint32_t nb01;
+    uint32_t a_offset;
 };
 
 struct vk_op_glu_push_constants {
@@ -1162,6 +1250,11 @@ struct vk_op_topk_moe_push_constants {
     uint32_t n_expert_used;
     float clamp_min;
     float clamp_max;
+    uint32_t gating_func;
+    uint32_t has_bias;
+    uint32_t with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 struct vk_op_add_id_push_constants {
@@ -1181,24 +1274,30 @@ struct vk_op_diag_mask_push_constants {
 
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
-    uint32_t ncols;
+    uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
-    uint32_t p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint32_t has_ff;
-    uint32_t ne02;
-    uint32_t s1;
-    uint32_t s2;
     int32_t sections[4];
     uint32_t is_imrope;
     uint32_t is_back;
     uint32_t set_rows_stride;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
 };
+static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
 
 // For fused rms_norm+mul+rope(+view+set_rows)
 struct vk_op_rms_norm_mul_rope_push_constants {
@@ -1260,6 +1359,7 @@ struct vk_op_im2col_push_constants {
     int32_t s0; int32_t s1;
     int32_t p0; int32_t p1;
     int32_t d0; int32_t d1;
+    uint32_t batch_IC;
 };
 
 struct vk_op_im2col_3d_push_constants {
@@ -1446,6 +1546,32 @@ template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
     init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);
 }
 
+struct vk_quantize_q8_1_push_constants {
+    uint32_t ne;
+    uint32_t num_blocks;
+};
+
+struct vk_op_flash_attn_split_k_reduce_push_constants {
+    uint32_t D;
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+    uint32_t k_num;
+    uint32_t sinks;
+};
+
+struct vk_op_flash_attn_mask_opt_push_constants {
+    uint32_t nem0;
+    uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nbm1;
+    uint32_t nbm2;
+    uint32_t nbm3;
+    uint32_t nbd1;
+    uint32_t nbd2;
+    uint32_t nbd3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1489,8 +1615,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
 static void ggml_vk_load_shaders(vk_device& device);
 static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
 
-#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
-#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
+static bool vk_memory_logger_enabled = false;
+
+#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << "ggml_vulkan memory: " << msg << std::endl; }
 
 static std::string format_size(size_t size) {
     const size_t kib = 1024;
@@ -1523,14 +1650,17 @@ class vk_memory_logger {
     std::map allocations; // Track allocations
     size_t total_device;
     size_t total_host;
+    static std::mutex log_mutex;
 };
-#else
-#define VK_LOG_MEMORY(msg) ((void) 0)
-#endif // GGML_VULKAN_MEMORY_DEBUG
+
+std::mutex vk_memory_logger::log_mutex;
 
 static bool vk_perf_logger_enabled = false;
+static bool vk_perf_logger_concurrent = false;
+static bool vk_enable_sync_logger = false;
 // number of calls between perf logger prints
 static uint32_t vk_perf_logger_frequency = 1;
+static std::string vk_pipeline_stats_filter;
 
 class vk_perf_logger {
   public:
@@ -1551,7 +1681,7 @@ class vk_perf_logger {
                 total_op_times += time;
             }
             std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
-                      << " us";
+                      << " us = " << (total_op_times / 1000.0) << " us";
 
             // If we have as many flops entries as timing entries for the op, then compute and log the flops/S.
             auto it = flops.find(t.first);
@@ -1579,14 +1709,14 @@ class vk_perf_logger {
         flops.clear();
     }
 
-    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
+    std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
+        *n_flops = 0;
         std::string fusion_str;
         if (fusion_name) {
             fusion_str = fusion_name + std::string(" ");
         }
         if (node->op == GGML_OP_UNARY) {
-            timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
-            return;
+            return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
             const uint64_t m     = node->ne[0];
@@ -1608,9 +1738,8 @@ class vk_perf_logger {
                 name += " batch=" + std::to_string(batch);
             }
             name = fusion_str + name;
-            timings[name].push_back(time);
-            flops[name].push_back(m * n * (k + (k - 1)) * batch);
-            return;
+            *n_flops = m * n * (k + (k - 1)) * batch;
+            return name;
         }
         if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
             std::string   name    = ggml_op_name(node->op);
@@ -1626,20 +1755,17 @@ class vk_perf_logger {
             uint64_t      size_M  = Cout;
             uint64_t      size_K  = Cin * KW * KH;
             uint64_t      size_N  = N * OW * OH;
-            uint64_t      n_flops = size_M * size_N * (size_K + (size_K - 1));
+            *n_flops = size_M * size_N * (size_K + (size_K - 1));
             name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
                     ", N=N*OW*OH=" + std::to_string(size_N);
             name = fusion_str + name;
-            flops[name].push_back(n_flops);
-            timings[name].push_back(time);
-            return;
+            return name;
         }
         if (node->op == GGML_OP_RMS_NORM) {
             std::string   name    = ggml_op_name(node->op);
             name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
             name = fusion_str + name;
-            timings[name].push_back(time);
-            return;
+            return name;
         }
         if (node->op == GGML_OP_FLASH_ATTN_EXT) {
             const ggml_tensor * dst = node;
@@ -1655,8 +1781,8 @@ class vk_perf_logger {
                 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
                 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
                 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
-            timings[name.str()].push_back(time);
-            return;
+            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
+            return name.str();
         }
         if (node->op == GGML_OP_TOP_K) {
             std::stringstream name;
@@ -1664,11 +1790,38 @@ class vk_perf_logger {
             name << ggml_op_name(node->op) <<
                 " K=" << node->ne[0] <<
                 " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
-            timings[name.str()].push_back(time);
-            return;
+            return name.str();
+        }
+        return fusion_str + ggml_op_name(node->op);
+    }
+
+    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
+        uint64_t n_flops;
+        std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
+        if (n_flops) {
+            flops[name].push_back(n_flops);
+        }
+        timings[name].push_back(time);
+    }
+
+    void log_timing(const std::vector &nodes, const std::vector &names, uint64_t time) {
+        uint64_t total_flops = 0;
+        std::string name;
+        for (size_t n = 0; n < nodes.size(); ++n) {
+            uint64_t n_flops = 0;
+            name += get_node_fusion_name(nodes[n], names[n], &n_flops);
+            total_flops += n_flops;
+
+            if (n != nodes.size() - 1) {
+                name += ", ";
+            }
         }
-        timings[fusion_str + ggml_op_name(node->op)].push_back(time);
+        if (total_flops) {
+            flops[name].push_back(total_flops);
+        }
+        timings[name].push_back(time);
     }
+
   private:
     std::map> timings;
     std::map> flops;
@@ -1707,7 +1860,10 @@ struct ggml_backend_vk_context {
     bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 
     vk_context_ref compute_ctx;
+
     vk_context_ref transfer_ctx;
+    vk_semaphore transfer_semaphore;
+    uint64_t transfer_semaphore_last_submitted {};
 
     std::vector tensor_ctxs;
 
@@ -1726,12 +1882,16 @@ struct ggml_backend_vk_context {
     // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
     // If there's no fusion, bit 0 is still set.
     int fused_ops_write_mask {};
+    topk_moe_mode fused_topk_moe_mode {};
+    bool fused_topk_moe_scale {};
 
     // for GGML_VK_PERF_LOGGER
     std::unique_ptr perf_logger;
     vk::QueryPool query_pool;
     std::vector query_fusion_names;
+    std::vector query_fusion_node_count;
     std::vector query_nodes;
+    std::vector query_node_idx;
     int32_t num_queries {};
     int32_t query_idx {};
 };
@@ -1805,10 +1965,10 @@ struct ggml_backend_vk_buffer_context {
     }
 };
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
-static std::mutex log_mutex;
-
 void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
+    if (!vk_memory_logger_enabled) {
+        return;
+    }
     std::lock_guard guard(log_mutex);
     vk_buffer buf = buf_ref.lock();
     const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -1820,7 +1980,7 @@ void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
 }
 
 void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
-    if (buf_ref.expired() || buf_ref.lock()->size == 0) {
+    if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) {
         return;
     }
 
@@ -1838,7 +1998,6 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
         VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
     }
 }
-#endif // GGML_VULKAN_MEMORY_DEBUG
 
 struct vk_instance_t {
     vk::Instance instance;
@@ -1988,6 +2147,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         compute_pipeline_create_info.setPNext(&rci);
     }
 
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
+    if (pipeline->is_64b_indexing)
+    {
+        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
+        if (device->pipeline_executable_properties_support) {
+            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
+        }
+        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
+        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
+    }
+#endif
+
     try {
         pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
     } catch (const vk::SystemError& e) {
@@ -2010,7 +2182,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         executableInfo.pipeline = pipeline->pipeline;
 
         auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
+
+        bool print_stats = !vk_pipeline_stats_filter.empty() &&
+                           pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
+        if (print_stats) {
+            std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
+        }
+
         for (auto & s : statistics) {
+            if (print_stats) {
+                std::cerr << "ggml_vulkan:   " << s.name.data() << ": ";
+                switch (s.format) {
+                    case vk::PipelineExecutableStatisticFormatKHR::eBool32:
+                        std::cerr << (s.value.b32 ? "true" : "false");
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eInt64:
+                        std::cerr << s.value.i64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eUint64:
+                        std::cerr << s.value.u64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
+                        std::cerr << s.value.f64;
+                        break;
+                }
+                std::cerr << std::endl;
+            }
             // "Register Count" is reported by NVIDIA drivers.
             if (strcmp(s.name, "Register Count") == 0) {
                 VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
@@ -2326,7 +2523,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe
     return indices;
 }
 
-static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) {
+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list,
+                                       void *import_ptr = nullptr) {
     VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
     if (size > device->max_buffer_size) {
         throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
@@ -2355,6 +2553,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         nullptr,
     };
 
+    vk::ExternalMemoryBufferCreateInfo external_memory_bci;
+    if (import_ptr) {
+        external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
+        buffer_create_info.setPNext(&external_memory_bci);
+    }
+
     buf->buffer = device->device.createBuffer(buffer_create_info);
 
     vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
@@ -2369,35 +2573,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         mem_flags_info.setPNext(&mem_priority_info);
     }
 
-    for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
-        const auto & req_flags = *it;
+    if (import_ptr) {
+        vk::MemoryHostPointerPropertiesEXT host_pointer_props;
+        try {
+            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
+        } catch (vk::SystemError& e) {
+            GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
+            device->device.destroyBuffer(buf->buffer);
+            return {};
+        }
+        vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
 
-        const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
+        uint32_t memory_type_idx;
+        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
+        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
+            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
+                continue;
+            }
+            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
+                continue;
+            }
 
-        if (memory_type_indices.empty()) {
-            continue;
+            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
+            // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
+            if ((memory_type.propertyFlags & property_flags) == property_flags) {
+                property_flags = memory_type.propertyFlags;
+                break;
+            }
+        }
+        if (memory_type_idx == 32) {
+            GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
+            device->device.destroyBuffer(buf->buffer);
+            return {};
         }
-        buf->memory_property_flags = req_flags;
 
-        bool done = false;
+        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
+        try {
+            vk::ImportMemoryHostPointerInfoEXT import_info;
+            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
+            import_info.pHostPointer = import_ptr;
+            import_info.setPNext(&mem_flags_info);
+            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
+        } catch (const vk::SystemError& e) {
+        }
+    } else {
+        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
+            const auto & req_flags = *it;
 
-        for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
-            try {
-                buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
-                done = true;
-                break;
-            } catch (const vk::SystemError& e) {
-                // loop and retry
-                // during last attempt throw the exception
-                if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
-                    device->device.destroyBuffer(buf->buffer);
-                    throw e;
+            const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
+
+            if (memory_type_indices.empty()) {
+                continue;
+            }
+            buf->memory_property_flags = req_flags;
+
+            bool done = false;
+
+            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
+                try {
+                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
+                    done = true;
+                    break;
+                } catch (const vk::SystemError& e) {
+                    // loop and retry
+                    // during last attempt throw the exception
+                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
+                        device->device.destroyBuffer(buf->buffer);
+                        throw e;
+                    }
                 }
             }
-        }
 
-        if (done) {
-            break;
+            if (done) {
+                break;
+            }
         }
     }
 
@@ -2408,8 +2657,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
 
     buf->ptr = nullptr;
 
-    if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+    if (import_ptr) {
+        buf->ptr = import_ptr;
+    } else {
+        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+        }
     }
 
     device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
@@ -2422,9 +2675,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         buf->bda_addr = device->device.getBufferAddress(addressInfo);
     }
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
     device->memory_logger->log_allocation(buf, size);
-#endif
 
     return buf;
 }
@@ -2481,11 +2732,9 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
         return;
     }
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
     if (buf->device != nullptr) {
         buf->device->memory_logger->log_deallocation(buf);
     }
-#endif
 
     buf.reset();
 }
@@ -2516,6 +2765,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
     );
 }
 
+static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
+    VK_LOG_DEBUG("ggml_vk_set_event()");
+
+    ctx->s->buffer.setEvent(
+        event,
+        ctx->p->q->stage_flags
+    );
+}
+
 static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) {
     VK_LOG_DEBUG("ggml_vk_wait_events()");
     if (events.empty()) {
@@ -2532,79 +2790,218 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events
     );
 }
 
-// number of rows/cols for flash attention shader
-static constexpr uint32_t flash_attention_num_small_rows = 32;
-static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+struct vk_fa_tuning_params {
+    FaCodePath path;
+    uint32_t workgroup_size;
+    uint32_t subgroup_size;
+    uint32_t block_rows;
+    uint32_t block_cols;
+    uint32_t d_split;
+    uint32_t row_split;
+    bool shmem_staging;
+    bool disable_subgroups;
+    uint32_t limit_occupancy_shmem;
+
+    void print() const {
+        std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
+                     " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
+                     " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
+                     " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
+    }
+};
+
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
 
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
-    if (hsv >= 192) {
-        return 2;
-    } else if ((hsv | hsk) & 8) {
-        return 4;
+static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(kv_type);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_SCALAR;
+
+    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+        // Disable subgroup use due to performance issues when enforcing subgroup sizes
+        result.subgroup_size = 32;
+        result.disable_subgroups = true;
+    } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
+        result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
     } else {
-        return 8;
+        result.subgroup_size = device->subgroup_size;
     }
-}
 
-// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
-// 128 threads split into four subgroups, each subgroup does 1/4
-// of the Bc dimension.
-static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
-static constexpr uint32_t scalar_flash_attention_Bc = 64;
-static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+    // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
+    uint32_t row_split_max_hsk = 64;
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
+        row_split_max_hsk = n_rows <= 8 ? 64 : 128;
+    }
+    result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
 
-static uint32_t get_fa_num_small_rows(FaCodePath path) {
-    if (path == FA_COOPMAT2) {
-        return flash_attention_num_small_rows;
+    if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
+        result.workgroup_size = result.subgroup_size * 2;
     } else {
-        return scalar_flash_attention_num_small_rows;
+        result.workgroup_size = result.subgroup_size * 4;
     }
-}
 
-static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
-    GGML_UNUSED(clamp);
-    GGML_UNUSED(hsv);
+    const uint32_t D = hsk | hsv;
 
-    if (path == FA_SCALAR) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, 64};
+    const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
+
+    if (n_rows == 1) {
+        result.block_rows = 1;
+        result.block_cols = 64;
+    } else {
+        // row_split 1 means higher register use per row, so block size has to be adjusted
+        if (result.row_split == 1) {
+            result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
         } else {
-            if ((hsv | hsk) & 8) {
-                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
-                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
-            } else {
-                return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
-            }
+            result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
         }
+
+        result.block_cols = (D & 8) ? 64 : 32;
     }
 
-    if (path == FA_COOPMAT1) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
-        } else {
-            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
+        result.block_rows /= 2;
+    }
+
+    // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
+    // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
+    // This targets an occupancy of 4 subgroups per SIMD.
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
+        if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
+            // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
+            // Values are guessed, tested on RDNA2
+            result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
+        } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
+            // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
+            // Here low-batch FA with large head size is affected.
+            // n_rows < 4 switch because workgroup size switches from 128 to 256 there.
+            result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
         }
     }
 
-    // small rows, large cols
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_rows);
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(kv_type);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT1;
+
+    const uint32_t D = hsk | hsv;
+
+    const uint32_t coopmat_block_rows = 16;
+    const uint32_t coopmat_block_cols = 16;
+
+    const uint32_t num_subgroups = 4;
+
+    result.block_rows = coopmat_block_rows;
+    result.block_cols = coopmat_block_cols * num_subgroups;
+    result.row_split = num_subgroups;
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = num_subgroups * result.subgroup_size;
+
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT2;
+
+    const uint32_t D = hsk | hsv;
+
+    const bool small_rows = n_rows < 32;
+
     if (small_rows) {
-        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
+        result.block_rows = 32;
+        result.block_cols = 32;
+    } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
+        result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
+        result.block_cols = 32;
+    } else {
+        result.block_rows = 64;
+        result.block_cols = 64;
     }
 
-    // small cols to reduce register count
-    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
-        if (hsk >= 512 || hsv >= 512) {
-            return {32, 32};
-        } else {
-            return {64, 32};
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
+                      device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
+
+    if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
+        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
+        path = FA_SCALAR;
+    }
+
+    if (path == FA_COOPMAT1) {
+        bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
+                        (!f32acc && device->coopmat_support_16x16x16_f16acc);
+        const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+        bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
+
+        if (!shape_ok || !shmem_ok) {
+            path = FA_SCALAR;
         }
     }
-    return {64, 64};
+
+    // scalar is faster than coopmat when N==1
+    if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
+        path = FA_SCALAR;
+    }
+
+    switch (path) {
+    case FA_SCALAR:
+        return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT1:
+        return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT2:
+        return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    default:
+        throw std::runtime_error("unsupported FaCodePath");
+    }
+}
+
+static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
+                                                  bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
+    const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
+                                 (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
+
+    uint32_t flags = (use_mask_opt      ? 1 : 0) |
+                     (use_mask          ? 2 : 0) |
+                     (use_logit_softcap ? 4 : 0) |
+                     (old_amd_windows   ? 8 : 0);
+
+    const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
+
+    return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
 }
 
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
+static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
+    return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
+            state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
 }
 
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -2613,7 +3010,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     switch (src0_type) {
     case GGML_TYPE_IQ1_S:
     case GGML_TYPE_IQ1_M:
-        lut_size = 2*2048;
+        lut_size = 2*2048 + 4*2048;
         break;
     case GGML_TYPE_IQ2_XXS:
         lut_size = 8*256;
@@ -2784,9 +3181,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32,  64,  1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
-        m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
-        s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
+        l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
+        m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
+        s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
         l_mmqid_wg_denoms = { 128, 128, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
         s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -2806,44 +3203,55 @@ static void ggml_vk_load_shaders(vk_device& device) {
         const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
         const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
 
-        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
-        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
-        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
 
-        l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
-        m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
-        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+        l_warptile = { 128,             128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile = { 128,              64,  64, 16, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile = { subgroup_size_32, 32,  32, 16, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+
+        l_warptile_mmq = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
+        m_warptile_mmq = { 128,              64,  64, 32, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+        s_warptile_mmq = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
         // Integer MMQ has a smaller shared memory profile, but heavier register use
-        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int = { 128,              64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, 2, 1, 1, subgroup_size_8 };
 
         // K-quants use even more registers, mitigate by setting WMITER to 1
-        l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int_k = { 128,  64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int_k = { 128,               128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int_k = { 128,                64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int_k = { subgroup_size_32,   32,  32, 32, s_warptile_wm,       32, 1, 2, 1, 1, subgroup_size_8 };
 
-        l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
-        m_warptile_id = { 128,  64,  64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
-        s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
+        l_warptile_id = { 128,                      128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
+        m_warptile_id = { 128,                       64,  64, 16, mul_mat_subgroup_size_16,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
+        s_warptile_id = { mul_mat_subgroup_size_16,  32,  32, 16, s_warptile_wm,                32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
 
-        l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
-        m_warptile_mmqid = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
-        s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+        l_warptile_mmqid = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
-        m_warptile_mmqid_int = { 128,  64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
-        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+        l_warptile_mmqid_int = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid_int = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
-        m_warptile_mmqid_int_k = { 128,  64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
-        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+        l_warptile_mmqid_int_k = { 128,                     128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+        m_warptile_mmqid_int_k = { 128,                      64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,  32, 32, s_warptile_wm,                32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
 
         // chip specific tuning
         if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
             m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
             m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+        } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
+            // This is intentionally using tx_m values, slight performance increase
+            l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+            l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+            l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
+        } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
+            // Xe2/Xe3 with coopmat enabled - warptile performance tuning
+            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
+            l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
         }
 
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2899,7 +3307,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
                                               uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
                                               uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
@@ -2907,35 +3315,49 @@ static void ggml_vk_load_shaders(vk_device& device) {
             required_subgroup_size = get_subgroup_size(name, device->architecture);
         }
 
-        if (!pipeline) {
-            pipeline = std::make_shared<vk_pipeline_struct>();
-        }
-        if (!pipeline->initialized) {
-            pipeline->name = name;
-            pipeline->parameter_count = parameter_count;
-            pipeline->push_constant_size = push_constant_size;
-            pipeline->wg_denoms = wg_denoms;
-            pipeline->align = align;
-            pipeline->initialized = true;
-        }
+        vk_pipeline *ptr = &base_pipeline;
 
-        if (!pipeline->needed || pipeline->compiled) {
-            return;
+        int num_pipelines = 1;
+#if defined(VK_EXT_shader_64bit_indexing)
+        if (device->shader_64b_indexing) {
+            num_pipelines = 2;
         }
-        // TODO: We're no longer benefitting from the async compiles (shaders are
-        // compiled individually, as needed) and this complexity can be removed.
-        {
-            // wait until fewer than N compiles are in progress
-            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
-            std::unique_lock<std::mutex> guard(compile_count_mutex);
-            while (compile_count >= N) {
-                compile_count_cond.wait(guard);
+#endif
+        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
+            vk_pipeline &pipeline = *ptr;
+            if (!pipeline) {
+                pipeline = std::make_shared<vk_pipeline_struct>();
+            }
+            if (!pipeline->initialized) {
+                pipeline->name = name;
+                pipeline->parameter_count = parameter_count;
+                pipeline->push_constant_size = push_constant_size;
+                pipeline->wg_denoms = wg_denoms;
+                pipeline->align = align;
+                pipeline->initialized = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+                pipeline->is_64b_indexing = (i == 1);
+#endif
+            }
+
+            if (!pipeline->needed || pipeline->compiled) {
+                continue;
+            }
+            // TODO: We're no longer benefitting from the async compiles (shaders are
+            // compiled individually, as needed) and this complexity can be removed.
+            {
+                // wait until fewer than N compiles are in progress
+                uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+                std::unique_lock<std::mutex> guard(compile_count_mutex);
+                while (compile_count >= N) {
+                    compile_count_cond.wait(guard);
+                }
+                compile_count++;
             }
-            compile_count++;
-        }
 
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
-                                      parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+            compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
+                                          parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+        }
     };
 
     auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
@@ -2946,59 +3368,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                        align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
-    };
-
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
-        // For large number of rows, 128 invocations seems to work best.
-        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-        // can't use 256 for D==80.
-        // For scalar, use 128 (arbitrary)
-        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
-        const uint32_t D = (hsk|hsv);
-        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
-                            ? scalar_flash_attention_workgroup_size
-                            : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
-
-        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
-        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
-        const uint32_t D_lsb = D ^ (D & (D-1));
-        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
-
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
-    };
-
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
-            uint32_t HSK = fa.first.HSK; \
-            uint32_t HSV = fa.first.HSV; \
-            bool small_rows = fa.first.small_rows; \
             FaCodePath path = fa.first.path; \
+            uint32_t Br = fa.first.Br; \
+            uint32_t Bc = fa.first.Bc; \
             bool aligned = fa.first.aligned; \
             bool f32acc = fa.first.f32acc; \
+            uint32_t fa_sgs = fa.first.subgroup_size; \
+            bool fa_ds = fa.first.subgroup_size == 0; \
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1,                                        true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1,                                        true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } \
             } \
         }
 
-    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    if (device->fp16) {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    } else {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
+    }
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat1_fa_support) {
         CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
@@ -3021,17 +3427,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif
 #undef CREATE_FA
 
+    const int mul_mat_id_param_count = 5;
+
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {
 
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true);   \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true);   \
 
         // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
@@ -3067,32 +3475,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         GGML_ASSERT(device->subgroup_ballot);
 
-        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         if (device->coopmat_bf16_support) {
-            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
         }
 #endif
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -3181,35 +3589,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         GGML_ASSERT(device->subgroup_ballot);
 
-        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         if (device->coopmat_bf16_support) {
-            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
         }
 #endif
 
-        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
+        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
 #undef CREATE_MM2
 #undef CREATE_MM
     } else
@@ -3294,91 +3702,91 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif
 
         if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
-            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-
-            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+
+            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
             if (device->integer_dot_product) {
-                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-
-                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-
-                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+
+                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+
+                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
             }
 #endif
         } else {
-            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
-
-            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
             if (device->integer_dot_product) {
-                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-
-                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
-
-                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
-                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
             }
 #endif
         }
@@ -3455,57 +3863,57 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #endif
 
         if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
-            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16);
-            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
-
-            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_subgroup_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_subgroup_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_subgroup_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_subgroup_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_subgroup_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_subgroup_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_subgroup_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
-            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_subgroup_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
+
+            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_subgroup_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_subgroup_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_subgroup_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_subgroup_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_subgroup_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_subgroup_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_subgroup_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
+            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_subgroup_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
         } else {
-            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
-
-            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
-            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+
+            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
+            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
         }
     }
     // reusing CREATE_MM from the fp32 path
@@ -3514,17 +3922,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
         && !device->coopmat_bf16_support
 #endif
         ) {
+        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
+
         // use scalar tile sizes
         l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
         m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
-        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
 
         l_wg_denoms = {128, 128, 1 };
         m_wg_denoms = { 64,  64, 1 };
         s_wg_denoms = { 32,  32, 1 };
 
+        if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
+            // Xe2/Xe3 - bf16 warptile performance tuning
+            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
+        }
+
         CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
-        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
     }
 #undef CREATE_MM
 
@@ -3535,6 +3950,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t rm_kq = 2;
     uint32_t rm_stdq_int = 1;
     uint32_t rm_kq_int = 1;
+    auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
     if (device->vendor_id == VK_VENDOR_ID_AMD) {
         if (device->architecture == AMD_GCN) {
             rm_stdq = 2;
@@ -3638,6 +4054,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
                 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
                 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
+                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
+
             }
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
         }
@@ -3671,19 +4091,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
             const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
             const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
 
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
         }
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
@@ -3691,6 +4114,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     GGML_UNUSED(rm_stdq_int);
     GGML_UNUSED(rm_kq_int);
+    GGML_UNUSED(rm_iq_int);
 #endif
 
     // dequant shaders
@@ -3767,12 +4191,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+
+    for (auto &it : device->pipeline_fa_mask_opt) {
+        auto BrBc = it.first;
+        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
+    }
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
-        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
     }
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -3799,7 +4228,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -3900,7 +4329,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3909,6 +4339,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -3949,6 +4380,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(gelu_quick)
     CREATE_UNARY(silu)
     CREATE_UNARY(relu)
+    CREATE_UNARY(xielu)
     CREATE_UNARY(neg)
     CREATE_UNARY(tanh)
     CREATE_UNARY(sigmoid)
@@ -3978,9 +4410,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
 #define CREATE_GLU(name)  \
     if (device->float_controls_rte_fp16) {  \
@@ -4030,6 +4462,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4038,6 +4471,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
     for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -4073,10 +4507,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32,       "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
+
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
 
@@ -4118,8 +4558,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
-        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
-        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
@@ -4227,9 +4667,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (uint32_t use_push = 0; use_push < 2; ++use_push) {
         for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push],      "topk_moe_f32_early_softmax_"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push],       "topk_moe_f32_late_softmax"+std::to_string(i),         topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i}, 1, true, true, device->subgroup_size);
         }
     }
 
@@ -4239,6 +4677,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 }
 
 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -4248,9 +4687,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk_device device = std::make_shared<vk_device_struct>();
         vk_instance.devices[idx] = device;
 
-#ifdef GGML_VULKAN_MEMORY_DEBUG
         device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
-#endif
 
         size_t dev_num = vk_instance.device_indices[idx];
 
@@ -4288,6 +4725,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool pipeline_executable_properties_support = false;
         device->coopmat_support = false;
         device->integer_dot_product = false;
+        device->shader_64b_indexing = false;
         bool bfloat16_support = false;
 
         for (const auto& properties : ext_props) {
@@ -4333,6 +4771,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
             } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
                        getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
                 device->memory_priority = true;
+            } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
+                device->external_memory_host = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+            } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
+                device->shader_64b_indexing = true;
+#endif
             }
         }
 
@@ -4347,6 +4791,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceVulkan12Properties vk12_props;
         vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
         vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+        vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
 
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
@@ -4386,11 +4831,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
             last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
         }
 
+        if (device->external_memory_host) {
+            last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
+            last_struct = (VkBaseOutStructure *)&external_memory_host_props;
+        }
+
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
         device->vendor_id = device->properties.vendorID;
         device->driver_id = driver_props.driverID;
 
+        if (device->driver_id == vk::DriverId::eMoltenvk) {
+            // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
+            // is available in the Vulkan SDK.
+            device->external_memory_host = false;
+        }
+
         // Implementing the async backend interfaces seems broken on older Intel HW,
         // see https://github.com/ggml-org/llama.cpp/issues/17302.
         device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
@@ -4438,11 +4894,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->shader_core_count = sm_props.shaderSMCount;
         } else if (amd_shader_core_properties2) {
             device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
+        } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
         } else {
             device->shader_core_count = 0;
         }
         device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+        device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
         device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                       (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 #ifdef __APPLE__
@@ -4472,6 +4932,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
 
+        device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
+
         device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
 
         std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
@@ -4603,6 +5065,20 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_pipeline_executable_properties");
         }
 
+        if (device->external_memory_host) {
+            device_extensions.push_back("VK_EXT_external_memory_host");
+        }
+
+#if defined(VK_EXT_shader_64bit_indexing)
+        VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
+        shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
+        if (device->shader_64b_indexing) {
+            last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+            last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+            device_extensions.push_back("VK_EXT_shader_64bit_indexing");
+        }
+#endif
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@@ -4639,11 +5115,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
 #if defined(VK_KHR_cooperative_matrix)
         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
-
-        // coopmat1 fa shader currently assumes 32 invocations per subgroup
-        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
-                                      device->subgroup_size_control && device->subgroup_min_size <= 32 &&
-                                      device->subgroup_max_size >= 32;
+        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
 #endif
 
         if (coopmat2_support) {
@@ -4869,11 +5341,23 @@ static vk_device ggml_vk_get_device(size_t idx) {
             switch (device->vendor_id) {
 #ifndef GGML_VULKAN_RUN_TESTS
             case VK_VENDOR_ID_AMD:
-            case VK_VENDOR_ID_INTEL:
-                device->mul_mat_l[i] = false;
+                device->mul_mat_l[i]    = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary;
+                device->mul_mat_m[i]    = true;
+                device->mul_mat_s[i]    = true;
+                device->mul_mat_id_l[i] = false;
+                device->mul_mat_id_m[i] = true;
+                device->mul_mat_id_s[i] = true;
+                break;
+            case VK_VENDOR_ID_INTEL:
+                if (!device->coopmat_support || device->architecture != INTEL_XE2) {
+                    device->mul_mat_l[i] = false;
+                    device->mul_mat_id_l[i] = false;
+                } else {
+                    device->mul_mat_l[i] = true;  // if coopmat & XE2+, allow large matmul warptile config for Intel
+                    device->mul_mat_id_l[i] = true;
+                }
                 device->mul_mat_m[i] = true;
                 device->mul_mat_s[i] = true;
-                device->mul_mat_id_l[i] = false;
                 device->mul_mat_id_m[i] = true;
                 device->mul_mat_id_s[i] = true;
                 break;
@@ -4915,13 +5399,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         ggml_vk_load_shaders(device);
 
+        const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN;
+
         if (!device->single_queue) {
             const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
             ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
+
+            device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
         } else {
             // TODO: Use pointer or reference to avoid copy
             device->transfer_queue.copyFrom(device->compute_queue);
             device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
+
+            device->async_use_transfer_queue = false;
         }
 
         device->buffer_type = {
@@ -5196,6 +5686,13 @@ static void ggml_vk_instance_init() {
     }
 
     vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
+    vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
+    vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
+    vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
+    const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
+    if (GGML_VK_PIPELINE_STATS != nullptr) {
+        vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
+    }
     const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
 
     if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@@ -5242,22 +5739,30 @@ static void ggml_vk_instance_init() {
 
             if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
                 // Check if there are two physical devices corresponding to the same GPU
+                // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
+                // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
+                // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
+                // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
+                // driver is MoltenVK
                 auto old_device = std::find_if(
                     vk_instance.device_indices.begin(),
                     vk_instance.device_indices.end(),
-                    [&devices, &new_id](const size_t k){
+                    [&devices, &new_id, &new_driver](const size_t k){
                         vk::PhysicalDeviceProperties2 old_props;
+                        vk::PhysicalDeviceDriverProperties old_driver;
                         vk::PhysicalDeviceIDProperties old_id;
-                        old_props.pNext = &old_id;
+                        old_props.pNext = &old_driver;
+                        old_driver.pNext = &old_id;
                         devices[k].getProperties2(&old_props);
 
-                        bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
-                        equals = equals || (
+                        bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
+                        same_uuid = same_uuid || (
                             old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
                             std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
                         );
+                        bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
 
-                        return equals;
+                        return same_uuid && !both_molten_vk;
                     }
                 );
                 if (old_device == vk_instance.device_indices.end()) {
@@ -5294,6 +5799,10 @@ static void ggml_vk_instance_init() {
                             driver_priorities[vk::DriverId::eMesaNvk] = 2;
 #endif
                             break;
+                        case VK_VENDOR_ID_QUALCOMM:
+                            driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
+                            driver_priorities[vk::DriverId::eMesaTurnip] = 2;
+                            break;
                     }
                     driver_priorities[vk::DriverId::eMesaDozen] = 100;
 
@@ -5376,7 +5885,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     ctx->almost_ready_fence = ctx->device->device.createFence({});
 
     ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
-    ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
+    if (ctx->device->async_use_transfer_queue) {
+        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
+        vk::SemaphoreCreateInfo ci{};
+        ci.setPNext(&tci);
+        ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);
+        ctx->transfer_semaphore.value = 0;
+
+        ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
+    }
 
     if (vk_perf_logger_enabled) {
         ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
@@ -5518,6 +6035,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
             case GGML_TYPE_Q4_K:
             case GGML_TYPE_Q5_K:
             case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_IQ1_M:
                 break;
             default:
                 return nullptr;
@@ -5674,6 +6193,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
             case GGML_TYPE_Q4_K:
             case GGML_TYPE_Q5_K:
             case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_IQ1_M:
                 break;
             default:
                 return nullptr;
@@ -5872,9 +6393,13 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
         std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
     }
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
+    GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+                wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+                wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
     GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
+    GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));
 
     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -5917,6 +6442,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
     subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
 }
 
+static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
+    if (!ctx->compute_ctx.expired()) {
+        return ctx->compute_ctx.lock();
+    }
+
+    vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+
+    ctx->compute_ctx = result;
+    ggml_vk_ctx_begin(ctx->device, result);
+
+    if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
+        result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
+        ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
+    }
+
+    return result;
+}
+
+// Submit any pending transfer queue work and signal the transfer semaphore.
+// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.
+// Returns true if work was submitted.
+static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
+    if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
+        return false;
+    }
+
+    vk_context cpy_ctx = ctx->transfer_ctx.lock();
+    ggml_vk_ctx_end(cpy_ctx);
+
+    for (auto& cpy : cpy_ctx->in_memcpys) {
+        memcpy(cpy.dst, cpy.src, cpy.n);
+    }
+
+    ctx->transfer_semaphore.value++;
+    cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
+
+    ggml_vk_submit(cpy_ctx, {});
+    ctx->transfer_ctx.reset();
+    return true;
+}
+
 static size_t ggml_vk_align_size(size_t width, size_t align) {
     VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
     return CEIL_DIV(width, align) * align;
@@ -6055,13 +6621,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     }
 }
 
-static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
+static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
-    // Buffer is already mapped
-    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
-        GGML_ABORT("fatal error");
-    }
     // Check if src is pinned memory
     vk_buffer buf = nullptr;
     size_t buf_offset = 0;
@@ -6086,12 +6647,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
 
         ggml_vk_sync_buffers(nullptr, subctx);
         subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
-        return;
+        return true;
     }
     VK_LOG_DEBUG("STAGING");
 
     if (!sync_staging) {
-        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
+        // copy was not handled caller needs to fall back
+        return false;
     }
 
     // Staging buffer required
@@ -6115,9 +6677,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
             deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
         }
     }
+    return true;
 }
 
-static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
+static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
     VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
     return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
 }
@@ -6136,7 +6699,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
 
         vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
         ggml_vk_ctx_begin(dst->device, subctx);
-        ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
+        bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
+        GGML_ASSERT(ret);
         ggml_vk_ctx_end(subctx);
 
         for (auto& cpy : subctx->in_memcpys) {
@@ -6414,8 +6978,16 @@ static void ggml_vk_matmul(
         uint32_t padded_n) {
         VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+        uint32_t base_work_group_z = 0;
+        while (base_work_group_z < batch) {
+            uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+            const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
+            base_work_group_z += groups_z;
+        }
         return;
     }
 
@@ -6429,9 +7001,17 @@ static void ggml_vk_matmul(
     uint32_t k_split = CEIL_DIV(k, split_k);
     k_split = ROUNDUP_POW2(k_split, 256);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
-    // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+    uint32_t base_work_group_z = 0;
+    while (base_work_group_z < batch) {
+        uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+        const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
+        // Make sure enough workgroups get assigned for split k to work
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
+        base_work_group_z += groups_z;
+    }
     ggml_vk_sync_buffers(ctx, subctx);
     const std::array pc2 = { (uint32_t)(m * n * batch), split_k };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -6471,18 +7051,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context *
 
 static void ggml_vk_matmul_id(
         ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
-        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
+        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
         uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
         uint32_t padded_n) {
-    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
+    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
                                               nei0, nei1, nbi1, ne11, padded_n };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
 }
 
 static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -6654,10 +7234,34 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
 
     vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 });
+    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
+    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
+    const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint64_t>::max());
+    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
+
+    const vk_quantize_q8_1_push_constants pc = {
+        ne,
+        num_blocks,
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
+static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
+    GGML_UNUSED(ctx);
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk_pipeline *ptr = &pipeline;
+    while (*ptr) {
+        if ((*ptr)->is_64b_indexing) {
+            return *ptr;
+        }
+        ptr = &(*ptr)->next;
+    }
+#endif
+    return pipeline;
+}
+
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -6741,6 +7345,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -6799,7 +7407,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         }
 
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (qx_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
@@ -6938,7 +7545,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
     // Quantization overhead is not worth it for small k
     switch (device->vendor_id) {
     case VK_VENDOR_ID_NVIDIA:
-        if (src0_type == GGML_TYPE_Q2_K) {
+        if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
             return true;
         }
 
@@ -6969,6 +7576,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
             return false;
         }
 
+        if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
+            // Intel Windows proprietary driver tuning
+            switch (src0_type) {
+            case GGML_TYPE_MXFP4:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+                return false;
+            default:
+                return true;
+            }
+        }
+
         switch (src0_type) {
         // From tests on A770 Linux, may need more tuning
         case GGML_TYPE_Q4_0:
@@ -7050,6 +7669,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
     }
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
+
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 
@@ -7093,7 +7716,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7188,22 +7810,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
-    // compute
-    const vk_mat_vec_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d,
-        fusion_flags,
-        (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-                              {
-                                d_X,
-                                d_Y,
-                                d_D,
-                                d_F0,
-                                d_F1,
-                              },
-                              pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+    ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
+
+    uint32_t base_work_group_y = 0;
+    while (base_work_group_y < ne12 * ne13) {
+
+        uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+        const vk_mat_vec_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            stride_batch_x, stride_batch_y, stride_batch_d,
+            fusion_flags, base_work_group_y,
+            (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+                                  {
+                                    d_X,
+                                    d_Y,
+                                    d_D,
+                                    d_F0,
+                                    d_F1,
+                                  },
+                                  pc, { groups_x, groups_y, groups_z });
+        base_work_group_y += groups_y;
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -7245,9 +7874,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         gqa_ratio = 1;
     }
 
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
+
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7289,7 +7924,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         workgroups_z /= gqa_ratio;
     }
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             d_Qx,
             d_Qy,
@@ -7339,9 +7974,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_y = nb12 / sizeof(float);
 
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
+
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7378,7 +8018,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             d_Qx,
             d_Qy,
@@ -7397,8 +8037,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
     // where the M dimension is very large.
     // Split_k doesn't work with M splitting.
+    // This only supports batchsize == 1.
     const size_t nbytes = ggml_nbytes(src0);
-    const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
+    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
     if (needs_split) {
         // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
         const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
@@ -7429,10 +8070,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
         src1->nb[2] <= src1->nb[1] &&
         src1->nb[1] <= src1->nb[3] &&
         src0->ne[3] == 1 &&
-        src1->ne[3] == 1) {
+        src1->ne[3] == 1 &&
+        src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+        src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
     } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
-               !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
+               !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
+               src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+               src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+               src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
@@ -7465,6 +8111,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
 
+    const uint32_t nbi0 = ids->nb[0];
     const uint32_t nbi1 = ids->nb[1];
     const uint32_t nbi2 = ids->nb[2];
 
@@ -7539,6 +8186,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -7572,6 +8222,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     if (quantize_y) {
         to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
     }
+    vk_pipeline count_experts = ctx->device->pipeline_count_experts;
+
+    uint32_t expert_count_size = sizeof(uint32_t) * n_as;
 
     {
         if (
@@ -7587,6 +8240,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_size_y = y_sz;
             ggml_vk_preallocate_buffers(ctx, subctx);
         }
+        if (ctx->prealloc_size_split_k < expert_count_size) {
+            ctx->prealloc_size_split_k = expert_count_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
 
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
@@ -7599,6 +8256,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
+        ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
     }
 
     vk_buffer d_D = dst_buf_ctx->dev_buffer;
@@ -7648,6 +8306,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ggml_vk_sync_buffers(ctx, subctx);
         }
     }
+    // Count how many times each expert is used
+    vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+    {
+        const std::vector<uint32_t> pc = { (uint32_t)nei0,
+                                           (uint32_t)nei1,
+                                           (uint32_t)(nbi0 / ggml_type_size(ids->type)),
+                                           (uint32_t)(nbi1 / ggml_type_size(ids->type)),
+                                           (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
+        ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
+            { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
+    }
 
     if (x_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
@@ -7655,7 +8327,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
         ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
             { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
-        ggml_vk_sync_buffers(ctx, subctx);
     }
     if (y_non_contig) {
         if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
@@ -7679,6 +8350,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_y_last_tensor_used = src1;
         }
     }
+    ggml_vk_sync_buffers(ctx, subctx);
 
     uint32_t stride_batch_x = ne00*ne01;
     uint32_t stride_batch_y = ne10*ne11;
@@ -7695,7 +8367,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_matmul_id(
         ctx, subctx, pipeline,
         { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
-        { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz },
+        { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
         ne01, ne21, ne10, ne10, ne10, ne01,
         stride_batch_x, stride_batch_y, ne20*ne21,
         n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
@@ -7707,6 +8379,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     if (y_non_contig || quantize_y) {
         ctx->prealloc_y_need_sync = true;
     }
+    ctx->prealloc_split_k_need_sync = true;
 }
 
 static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7735,8 +8408,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-
-    GGML_ASSERT(nei1 == 1);
+    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
 
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
@@ -7777,6 +8449,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
+
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
@@ -7816,7 +8492,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7874,7 +8550,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     uint32_t stride_batch_y = ne10*ne11;
 
     if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
     }
 
     const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -7910,23 +8586,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
     }
 
-    // compute
-    const vk_mat_vec_id_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-        fusion_flags,
-        (uint32_t)nei0, (uint32_t)ne11,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        {
-            d_X,
-            d_Y,
-            d_D,
-            d_F0,
-            d_F1,
-            d_ids,
-        },
-        pc, { groups_x, (uint32_t)nei0, groups_z });
+    // Loop over the batch dimension
+    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
+        const vk_mat_vec_id_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
+            fusion_flags,
+            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+            {
+                d_X,
+                d_Y,
+                d_D,
+                d_F0,
+                d_F1,
+                d_ids,
+            },
+            pc, { groups_x, (uint32_t)nei0, groups_z });
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -7940,7 +8618,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
     ggml_tensor * dst = cgraph->nodes[node_idx];
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src2 = dst->src[2];
-    return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
+    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -7956,55 +8634,70 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
+    GGML_UNUSED(f32acc);
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const uint32_t wg_size = params.workgroup_size;
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
+
+    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
 
+    // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t D = std::max(hsk, hsv);
+    const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
+
+    const uint32_t MatBr = 16, MatBc = 16;
+
+    const uint32_t row_split = Bc / MatBc;
 
     const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+    const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
 
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
-    const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * acctype;
+    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
 
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
+    const uint32_t psh_stride = Br / 4 + 2;
+    const uint32_t Psh = Bc * psh_stride * f16vec4;
+
     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const uint32_t kshstride = hsk_pad / 4 + 2;
-    const uint32_t ksh = Bc * kshstride * f16vec4;
+    const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
+    const uint32_t vsh_stride = MatBc / 4 * row_split;
+    const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
+
+    const uint32_t osh_stride = params.row_split * MatBr / 4;
+    const uint32_t pvsh = MatBc * osh_stride * f16vec4;
 
-    const uint32_t slope = Br * sizeof(float);
+    const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
@@ -8031,6 +8724,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
+    const uint32_t nem0 = mask ? mask->ne[0] : 0;
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
     const uint32_t nem3 = mask ? mask->ne[3] : 0;
@@ -8064,70 +8758,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     assert(q->type == GGML_TYPE_F32);
     assert(k->type == v->type);
 
-    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
-                      ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
-
-    if (path == FA_COOPMAT1) {
-        const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
-                                             (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
-
-        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
-
-        if (!coopmat_shape_supported || !coopmat_shmem_supported) {
-            path = FA_SCALAR;
-        }
-    }
-
     uint32_t gqa_ratio = 1;
     uint32_t qk_ratio = neq2 / nek2;
     uint32_t workgroups_x = (uint32_t)neq1;
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
+    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
+
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
-    uint32_t max_gqa;
-    switch (path) {
-    case FA_SCALAR:
-    case FA_COOPMAT1:
-        // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
-        break;
-    case FA_COOPMAT2:
-        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
-        break;
-    default:
-        GGML_ASSERT(0);
-    }
+    vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
+    const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
 
-    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
         gqa_ratio = qk_ratio;
         N = gqa_ratio;
-        workgroups_y /= N;
+        workgroups_y /= gqa_ratio;
     }
 
-    bool small_rows = N <= get_fa_num_small_rows(path);
-
-    // coopmat1 does not actually support "small rows" (it needs 16 rows).
-    // So use scalar instead.
-    if (small_rows && path == FA_COOPMAT1) {
-        path = FA_SCALAR;
-    }
-
-    // scalar is faster than coopmat2 when N==1
-    if (N == 1 && path == FA_COOPMAT2) {
-        path = FA_SCALAR;
-    }
-
-    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
-    if (path == FA_SCALAR &&
-        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
-        small_rows = true;
-    }
+    tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
     uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@@ -8141,19 +8795,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride /= 4;
     }
 
-    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
+    const uint32_t alignment = tuning_params.block_cols;
     bool aligned = (KV % alignment) == 0 &&
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
     // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
-    if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+    if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc);
+    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+    vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
+                                                                   mask != nullptr, use_mask_opt, logit_softcap != 0);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8169,29 +8836,46 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     assert(pipeline);
+    // Compile early to initialize wg_denoms.
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
+    // Intel Alchemist prefers more workgroups
+    const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
+
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
-    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
 
-    // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0) {
-        // Try to run two workgroups per SM.
-        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
-        if (split_k > 1) {
-            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
-            // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
-            split_k = CEIL_DIV(KV, split_kv);
-            workgroups_x = split_k;
+    const uint32_t Br = fa_pipeline_state.Br;
+    const uint32_t Bc = fa_pipeline_state.Bc;
+
+    GGML_ASSERT(Br == pipeline->wg_denoms[0]);
+    const uint32_t Tr = CEIL_DIV(N, Br);
+
+    // Try to use split_k when KV is large enough to be worth the overhead.
+    if (gqa_ratio > 1 && workgroups_x <= Br) {
+        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
+    } else if (gqa_ratio <= 1) {
+        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
+        if (total_wgs_no_split < shader_core_count * 2) {
+            split_k = shader_core_count * 2 / total_wgs_no_split;
         }
     }
 
+    if (split_k > 1) {
+        // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+        // of "align", so recompute split_k based on that.
+        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
+        split_k = CEIL_DIV(KV, split_kv);
+    }
+
     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
-    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
+    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
     if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -8200,24 +8884,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
 
-    {
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
-        if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
-        }
-    }
-
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
+    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
+    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
 
-    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+    vk_pipeline pipeline_fa_mask_opt = nullptr;
+    if (use_mask_opt) {
+        std::lock_guard<std::mutex> guard(ctx->device->mutex);
+        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
+        auto it = pipelines.find({Br, Bc});
+        if (it != pipelines.end()) {
+            pipeline_fa_mask_opt = it->second;
+        } else {
+            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
+        }
+        assert(pipeline_fa_mask_opt);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
 
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
+        if (ctx->prealloc_size_y < mask_opt_size) {
+            ctx->prealloc_size_y = mask_opt_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
     }
 
     const uint32_t n_head_kv   = neq2;
@@ -8231,8 +8920,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
+
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
+
+    if (use_mask_opt)
+    {
+        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
+            nem0,
+            nem1,
+            nem2,
+            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
+            mask_opt_num_dwords,
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
+        };
 
-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
+                                  { mask_buf, mask_opt_buf }, opt_pc,
+                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
 
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -8248,28 +8958,40 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               gqa_ratio, split_kv, split_k };
 
     if (split_k > 1) {
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
 
+        // We reuse workgroups_x to mean the number of splits, so we need to
+        // cancel out the divide by wg_denoms[0].
+        uint32_t dispatch_x;
+        if (gqa_ratio > 1) {
+            workgroups_x *= pipeline->wg_denoms[0];
+            dispatch_x = split_k * workgroups_x;
+        } else {
+            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
+        }
+
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
-                                    // We only use split_k when group query attention is enabled, which means
-                                    // there's no more than one tile of rows (i.e. workgroups_x would have been
-                                    // one). We reuse workgroups_x to mean the number of splits, so we need to
-                                    // cancel out the divide by wg_denoms[0].
-                                    pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
+                                    pc, { dispatch_x, workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(ctx, subctx);
-        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
+        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {split_k_buf, sinks_buf, dst_buf},
-                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
         ctx->prealloc_split_k_need_sync = true;
     } else {
+        if (gqa_ratio > 1) {
+            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
+            workgroups_x *= pipeline->wg_denoms[0];
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
@@ -8314,6 +9036,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_acc_f32;
         }
         return nullptr;
+    case GGML_OP_SET:
+        if (src0->type == src1->type && src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
+            return ctx->device->pipeline_set_f32;
+        }
+        return nullptr;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -8378,7 +9106,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_UPSCALE:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF);
+            uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
             switch (mode) {
                 case GGML_SCALE_MODE_NEAREST:
                     return ctx->device->pipeline_upscale_nearest_f32;
@@ -8386,6 +9114,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                     return ctx->device->pipeline_upscale_bilinear_f32;
                 case GGML_SCALE_MODE_BICUBIC:
                     return ctx->device->pipeline_upscale_bicubic_f32;
+                case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
+                    return ctx->device->pipeline_upscale_bilinear_antialias_f32;
                 default:
                     return nullptr;
             }
@@ -8523,6 +9253,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_RELU:
                 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_XIELU:
+                return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_NEG:
                 return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_TANH:
@@ -8587,10 +9319,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (ctx->num_additional_fused_ops) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             GGML_ASSERT(idx < num_topk_moe_pipelines);
-            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
             // use n_experts from push constant if it's not equal to the power of two spec constant
             bool use_push = dst->ne[0] != (1u << idx);
-            return ctx->device->pipeline_topk_moe[idx][mode][use_push];
+            return ctx->device->pipeline_topk_moe[idx][use_push];
         }
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8628,6 +9359,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
                     return ctx->device->pipeline_rope_multi_f32;
                 }
+                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+                    return ctx->device->pipeline_rope_multi_f32_f16;
+                }
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
                     return ctx->device->pipeline_rope_multi_f16;
                 }
@@ -8660,7 +9394,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_CUMSUM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_cumsum_f32;
+            if (src0->ne[0] <= 512) {
+                return ctx->device->pipeline_cumsum_small_f32;
+            } else {
+                return ctx->device->pipeline_cumsum_f32;
+            }
         }
         return nullptr;
     case GGML_OP_SOLVE_TRI:
@@ -9031,10 +9769,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
         } break;
     case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
         break;
+    case GGML_OP_ROPE:
+    case GGML_OP_ROPE_BACK:
+        {
+            uint32_t nrows = (uint32_t)ggml_nrows(src0);
+            uint32_t z = 1;
+            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
+                z = CEIL_DIV(nrows, 32768);
+                nrows = 32768;
+            }
+            elements = { nrows, (uint32_t)ne00, z };
+
+        } break;
     case GGML_OP_GET_ROWS:
         elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
         elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
@@ -9058,6 +9806,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
             elements = { OW * KW * KH, OH, batch * IC };
+            elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
         } break;
     case GGML_OP_IM2COL_3D:
         {
@@ -9278,16 +10028,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
+    int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
+    int offset = dst->op_params[3] / src0_type_size; // offset in bytes
 
-    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         0,
         0.0f, 0.0f, offset,
     });
@@ -9597,8 +10347,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
 
     std::array<uint32_t, 3> elements;
 
-    const int splitH = 16;
-    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
+    const uint32_t d_state = src0->ne[0];
+    uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
+    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
     const uint32_t num_workgroups_y = n_seq;
     elements = { num_workgroups_x, num_workgroups_y, 1 };
 
@@ -9669,14 +10420,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
 
     ggml_vk_op_f32_opt_step_adamw(
         ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f }
+        { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
     );
 }
 
 static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     const size_t n = ggml_nelements(dst->src[0]);
 
-    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
+    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9762,6 +10513,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
         1,
         ggml_get_op_params_f32(dst, 0),
         ggml_get_op_params_f32(dst, 2),
+        0.0f, 0.0f,
     };
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9783,6 +10535,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
         1,
         ggml_get_op_params_f32(dst, 0),
         0.0f,
+        0.0f, 0.0f,
     };
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9898,13 +10651,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
 }
 
 static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9915,7 +10668,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     const float eps = float_op_params[1];
     const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
 
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
 }
 
 static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -9956,12 +10709,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
 
     uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
+    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
+    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
+        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+
+        (uint32_t)src0->ne[0],
+        (uint32_t)src0->ne[1],
+        (uint32_t)src0->ne[2],
+        nb01, nb02, nb03,
+        nb11, nb12, nb13,
     };
 
     return rope;
@@ -10084,16 +10847,28 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    const float * op_params = (const float *)dst->op_params;
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
+    p.param1 = op_params[0];
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
+}
+
+static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
+        {
+            (uint32_t)ggml_nelements(src0), 0,
+            op_params[1], op_params[2], op_params[3], op_params[4]
+        }
+    );
 }
 
 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10218,18 +10993,20 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
 }
 
 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
-    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+    topk_moe_mode mode = ctx->fused_topk_moe_mode;
     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
-    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
-                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
-                                                                    cgraph->nodes[node_idx + 5];
-    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
+    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
+    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
+                        (mode == TOPK_MOE_LATE_SOFTMAX) ?      cgraph->nodes[node_idx + 1] :
+                                                               cgraph->nodes[node_idx + 3];
 
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(bias->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
@@ -10244,6 +11021,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
     vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
+    vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
     vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
     vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
 
@@ -10251,18 +11029,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     pc.n_rows = n_rows;
     pc.n_experts_push = n_experts;
     pc.n_expert_used = n_expert_used;
+    pc.clamp_min = -std::numeric_limits<float>::infinity();
+    pc.clamp_max = std::numeric_limits<float>::infinity();
     if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
         ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
         pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
         pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
     }
+    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
+        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
+        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+    }
+
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
+                     mode == TOPK_MOE_LATE_SOFTMAX ?      GATING_FUNC_SOFTMAX_WEIGHT :
+                                                          GATING_FUNC_SOFTMAX;
+    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    if (ctx->fused_topk_moe_scale) {
+        GGML_ASSERT(weights->op == GGML_OP_SCALE);
+        pc.output_scale = ggml_get_op_params_f32(weights, 0);
+        pc.output_bias = ggml_get_op_params_f32(weights, 1);
+    } else {
+        pc.output_scale = 1.0f;
+        pc.output_bias = 0.0f;
+    }
 
     GGML_ASSERT(n_expert_used <= n_experts);
 
     const uint32_t rows_per_block = 4;
     std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@@ -10510,22 +11315,64 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 
 static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
-}
+    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
+    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
+    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+        return;
+    }
 
-static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
-}
+    // First pass computes partial sums within a block, and stores the last partial
+    // to the temp buffer. Second pass sums the block partials from the temp buffer
+    // and adds that to the result of the first pass.
+    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
 
-static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
-}
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
 
-static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const uint32_t src0_type_size = ggml_type_size(src0->type);
-    const uint32_t src1_type_size = ggml_type_size(src1->type);
-    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = dst->ne[0];
+    elements[1] = (uint32_t)ggml_nrows(dst);
+    elements[2] = 1;
+
+    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+    if (ctx->prealloc_size_split_k < temp_size) {
+        ctx->prealloc_size_split_k = temp_size;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+
+    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+    ggml_vk_sync_buffers(ctx, subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+    ctx->prealloc_split_k_need_sync = true;
+}
+
+static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
+}
+
+static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
+}
+
+static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
 
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
         (uint32_t)ggml_nelements(src0),
@@ -10561,6 +11408,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     const uint32_t pelements = OW * KW * KH;
+    const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
     const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@@ -10573,7 +11421,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
         IC, IW, IH, OW, OH, KW, KH,
         pelements,
         IC * KH * KW,
-        s0, s1, p0, p1, d0, d1,
+        s0, s1, p0, p1, d0, d1, batch * IC
     });
 }
 
@@ -10778,7 +11626,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     const float * op_params = (const float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 #ifdef GGML_VULKAN_RUN_TESTS
@@ -10924,7 +11772,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
@@ -11098,7 +11945,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     free(d_chk);
 
     ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
-    ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
 
     ggml_vk_destroy_buffer(d_X);
     ggml_vk_destroy_buffer(d_Y);
@@ -11434,7 +12280,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         // y[i] = i % k;
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
@@ -11447,7 +12292,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         }
     }
     if (mmq) {
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
+        vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx);
@@ -11683,7 +12529,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
         ggml_vk_submit(subctx, {});
         ctx->submit_pending = true;
         ggml_vk_synchronize(ctx);
+        GGML_ASSERT(ctx->compute_ctx.expired());
         ggml_vk_ctx_begin(ctx->device, subctx);
+        ctx->compute_ctx = subctx;
     }
 
     if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
@@ -11701,6 +12549,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
             ggml_vk_destroy_buffer(ctx->prealloc_y);
         }
         ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
+        ctx->prealloc_y_last_tensor_used = nullptr;
     }
     if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -11729,6 +12578,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
         return false;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return false;
+    }
 
     VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
     ctx->semaphore_idx = 0;
@@ -11753,15 +12605,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         }
     }
 
-    vk_context compute_ctx;
-
-    if (ctx->compute_ctx.expired()) {
-        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->compute_ctx = compute_ctx;
-        ggml_vk_ctx_begin(ctx->device, compute_ctx);
-    } else {
-        compute_ctx = ctx->compute_ctx.lock();
-    }
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
     {
         // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers
@@ -11822,15 +12666,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
 
-#define ENABLE_SYNC_LOGGING 0
-
         if (need_sync) {
-#if ENABLE_SYNC_LOGGING
-            std::cerr <<  "sync" << std::endl;
-#endif
+            if (vk_enable_sync_logger) {
+                std::cerr <<  "sync" << std::endl;
+            }
             ctx->unsynced_nodes_written.clear();
             ctx->unsynced_nodes_read.clear();
             ggml_vk_sync_buffers(ctx, compute_ctx);
+
+            if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
+                ctx->query_node_idx[ctx->query_idx] = node_idx;
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
+            }
         }
         // Add all fused nodes to the unsynchronized lists.
         for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
@@ -11847,20 +12694,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
     }
-#if ENABLE_SYNC_LOGGING
-    for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
-        auto *n = cgraph->nodes[node_idx + i];
-        std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " <<  n->name;
-        if (n->op == GGML_OP_GLU) {
-            std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
-        }
-        if (n->op == GGML_OP_ROPE) {
-            const int mode = ((const int32_t *) n->op_params)[2];
-            std::cerr << " rope mode: " << mode;
+    if (vk_enable_sync_logger) {
+        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+            auto *n = cgraph->nodes[node_idx + i];
+            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " <<  n->name;
+            if (n->op == GGML_OP_GLU) {
+                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+            }
+            if (n->op == GGML_OP_ROPE) {
+                const int mode = ((const int32_t *) n->op_params)[2];
+                std::cerr << " rope mode: " << mode;
+            }
+            std::cerr << std::endl;
         }
-        std::cerr << std::endl;
     }
-#endif
 
     switch (node->op) {
     case GGML_OP_REPEAT:
@@ -11872,6 +12719,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ACC:
+    case GGML_OP_SET:
         ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
 
         break;
@@ -12000,6 +12848,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_UNARY:
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+            break;
+        }
+
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
@@ -12021,6 +12874,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_TRUNC:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
+        case GGML_UNARY_OP_XIELU:
+            ggml_vk_xielu(ctx, compute_ctx, src0, node);
+            break;
         default:
             return false;
         }
@@ -12044,7 +12900,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12064,7 +12920,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -12267,7 +13123,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
 
     ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
-    ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
+    if (ctx->device->async_use_transfer_queue) {
+        ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
+    }
 
     for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
         ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -12296,7 +13154,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
 static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
     // discard any unsubmitted command buffers
-    ctx->transfer_ctx.reset();
+    ctx->compute_ctx.reset();
     // wait for any pending command buffers to finish
     ggml_vk_synchronize(ctx);
 
@@ -12329,7 +13187,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ctx->descriptor_sets.clear();
 
     ctx->compute_cmd_pool.destroy(ctx->device->device);
-    ctx->transfer_cmd_pool.destroy(ctx->device->device);
+    if (ctx->device->async_use_transfer_queue) {
+        ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);
+
+        ctx->transfer_cmd_pool.destroy(ctx->device->device);
+    }
     if (vk_perf_logger_enabled) {
         ctx->perf_logger->print_timings(true);
     }
@@ -12626,20 +13488,40 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
 
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
 
-    vk_context transfer_ctx;
+    vk_context cpy_ctx;
 
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    if (ctx->device->async_use_transfer_queue) {
+        if (ctx->transfer_ctx.expired()) {
+            // Initialize new transfer context
+            cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
+            ctx->transfer_ctx = cpy_ctx;
+            ggml_vk_ctx_begin(ctx->device, cpy_ctx);
+        } else {
+            cpy_ctx = ctx->transfer_ctx.lock();
+        }
     } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
+        cpy_ctx = ggml_vk_get_compute_ctx(ctx);
     }
 
     vk_buffer buf = buf_ctx->dev_buffer;
 
-    ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
+    auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
+
+    bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
+
+    if (!ret) {
+        ggml_vk_ensure_sync_staging_buffer(ctx, size);
+        ggml_vk_sync_buffers(nullptr, cpy_ctx);
+
+        vk::BufferCopy buffer_cpy;
+        buffer_cpy.srcOffset = 0;
+        buffer_cpy.dstOffset = dst_offset;
+        buffer_cpy.size = size;
+
+        cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
+        deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
+        ggml_vk_synchronize(ctx);
+    }
 }
 
 static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -12649,87 +13531,126 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
 
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
 
-    vk_context transfer_ctx;
-
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
     vk_buffer buf = buf_ctx->dev_buffer;
 
     auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
-    bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size);
+    bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
 
     // If that failed, copy synchronously through a staging buffer
     if (!ret) {
         ggml_vk_ensure_sync_staging_buffer(ctx, size);
-        ggml_vk_sync_buffers(nullptr, transfer_ctx);
+        ggml_vk_sync_buffers(nullptr, compute_ctx);
 
         vk::BufferCopy buffer_cpy;
         buffer_cpy.srcOffset = src_offset;
         buffer_cpy.dstOffset = 0;
         buffer_cpy.size = size;
 
-        transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
-        deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys);
+        compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
+        deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
         ggml_vk_synchronize(ctx);
     }
 }
 
-static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
+static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
     VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
-        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
-        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
 
-        vk_context transfer_ctx;
+    if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
+        return false;
+    }
 
-        if (ctx->transfer_ctx.expired()) {
-            // Initialize new transfer context
-            transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-            ctx->transfer_ctx = transfer_ctx;
-            ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-        } else {
-            transfer_ctx = ctx->transfer_ctx.lock();
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
+
+    if (ggml_backend_buffer_is_vk(src->buffer)) {
+        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
+
+        // Async copy only works within the same device
+        if (src_buf_ctx->dev_buffer->device != dst_buf->device) {
+            return false;
         }
 
-        vk_buffer src_buf = src_buf_ctx->dev_buffer;
-        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
+        vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
-        ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
+        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,
+                                   src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,
+                                   ggml_nbytes(src));
         return true;
     }
 
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        vk_buffer pinned_buf = nullptr;
+        size_t pinned_offset = 0;
+        ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);
+        if (pinned_buf == nullptr) {
+            return false;
+        }
+
+        vk_context cpy_ctx;
+        if (ctx->device->async_use_transfer_queue) {
+            if (ctx->transfer_ctx.expired()) {
+                cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
+                ctx->transfer_ctx = cpy_ctx;
+                ggml_vk_ctx_begin(ctx->device, cpy_ctx);
+            } else {
+                cpy_ctx = ctx->transfer_ctx.lock();
+            }
+        } else {
+            cpy_ctx = ggml_vk_get_compute_ctx(ctx);
+        }
+
+        return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,
+                                          vk_tensor_offset(dst) + dst->view_offs,
+                                          src->data, ggml_nbytes(src));
+    }
+
+    GGML_UNUSED(backend_src);
     return false;
 }
 
 static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_synchronize()");
 
-    bool do_transfer = !ctx->transfer_ctx.expired();
+    bool do_transfer = !ctx->compute_ctx.expired();
 
-    vk_context transfer_ctx;
+    if (ggml_vk_submit_transfer_ctx(ctx)) {
+        ctx->submit_pending = true;
+    }
+
+    vk_context compute_ctx;
     if (do_transfer) {
-        transfer_ctx = ctx->transfer_ctx.lock();
+        compute_ctx = ctx->compute_ctx.lock();
 
-        ggml_vk_ctx_end(transfer_ctx);
+        ggml_vk_ctx_end(compute_ctx);
 
-        for (auto& cpy : transfer_ctx->in_memcpys) {
+        for (auto& cpy : compute_ctx->in_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
-        ggml_vk_submit(transfer_ctx, {});
+        ggml_vk_submit(compute_ctx, {});
         ctx->submit_pending = true;
     }
 
     if (ctx->submit_pending) {
-        {
+        if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
+            vk::TimelineSemaphoreSubmitInfo tl_info{
+                1, &ctx->transfer_semaphore.value,
+                0, nullptr,
+            };
+            vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;
+            vk::SubmitInfo si{
+                1, &ctx->transfer_semaphore.s, &stage,
+                0, nullptr,
+                0, nullptr,
+            };
+            si.setPNext(&tl_info);
+            std::lock_guard<std::mutex> guard(queue_mutex);
+            ctx->device->compute_queue.queue.submit({ si }, ctx->fence);
+            ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
+        } else {
             std::lock_guard<std::mutex> guard(queue_mutex);
             ctx->device->compute_queue.queue.submit({}, ctx->fence);
         }
@@ -12738,10 +13659,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
     }
 
     if (do_transfer) {
-        for (auto& cpy : transfer_ctx->out_memcpys) {
+        for (auto& cpy : compute_ctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
-        ctx->transfer_ctx.reset();
+        ctx->compute_ctx.reset();
     }
 }
 
@@ -12916,42 +13837,81 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
 
     const ggml_tensor * softmax;
     const ggml_tensor * weights;
+    const ggml_tensor * get_rows;
+    const ggml_tensor * argsort;
 
     switch (mode) {
     case TOPK_MOE_EARLY_SOFTMAX_NORM:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 9];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
+        break;
+    case TOPK_MOE_SIGMOID_NORM_BIAS:
+        softmax = cgraph->nodes[node_idx + 0]; // in this mode the "softmax" slot holds a sigmoid unary op
+        weights = cgraph->nodes[node_idx + 10];
+        get_rows = cgraph->nodes[node_idx + 5];
+        argsort = cgraph->nodes[node_idx + 3];
+        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        // bias is expected to be 1D
+        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+            return false;
+        }
+        // sigmoid fusion seems to generate infinities on moltenvk
+        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+            return false;
+        }
         break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
+        get_rows = cgraph->nodes[node_idx + 4];
+        argsort = cgraph->nodes[node_idx + 2];
         break;
     case TOPK_MOE_LATE_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 4];
         weights = cgraph->nodes[node_idx + 5];
+        get_rows = cgraph->nodes[node_idx + 2];
+        argsort = cgraph->nodes[node_idx + 0];
         break;
     default:
         return false;
     }
 
-    const float * op_params = (const float *)softmax->op_params;
-
-    float scale = op_params[0];
-    float max_bias = op_params[1];
-
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+    ggml_tensor * probs = get_rows->src[0];
+    if (probs->op != GGML_OP_RESHAPE) {
         return false;
     }
+    probs = probs->src[0];
+    ggml_tensor * selection_probs = argsort->src[0];
 
-    if (scale != 1.0f || max_bias != 0.0f) {
+    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
         return false;
     }
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
         return false;
     }
 
+    if (softmax->op == GGML_OP_SOFT_MAX) {
+        const float * op_params = (const float *)softmax->op_params;
+
+        float scale = op_params[0];
+        float max_bias = op_params[1];
+
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+    }
+
     const int n_expert = softmax->ne[0];
     if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
         return false;
@@ -12993,21 +13953,20 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
         return false;
     }
 
-    // Only norm/neox shaders have the fusion code
+    // Only norm/neox/mrope shaders have the fusion code
     const int mode = ((const int32_t *) rope->op_params)[2];
-    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
         return false;
     }
 
     return true;
 }
 
-// Check whether the tensors overlap in memory but are not equal.
-// Fusions can potenitally overwrite src tensors in ways that are not prevented
-// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
-// to overlap if they are exactly equal.
-// XXX TODO this check is probably missing from several fusion optimizations.
-static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+// Check whether the tensors overlap in memory.
+// Fusions can potentially overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
+// with the destination, then it's OK for them to overlap if they are exactly equal.
+static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
     ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
     vk_buffer a_buf = a_buf_ctx->dev_buffer;
     ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
@@ -13018,7 +13977,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
         auto b_base = vk_tensor_offset(b) + b->view_offs;
         auto b_size = ggml_nbytes(b);
 
-        if (a_base == b_base && a_size == b_size) {
+        if (elementwise && a_base == b_base && a_size == b_size) {
             return false;
         }
 
@@ -13056,13 +14015,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
         return false;
     }
 
-    // must not overwrite srcs in a way that's not elementwise
-    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
-    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
-        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
-        return false;
-    }
-
     // conditions for pipeline creation
     if (!(ctx->device->float_controls_rte_fp16 &&
         sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
@@ -13124,8 +14076,21 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
     return num_adds;
 }
 
+static int32_t find_first_set(uint32_t x) {
+    int32_t ret = 0;
+    if (!x) {
+        return -1;
+    }
+    while (!(x & 1)) {
+        x >>= 1;
+        ret++;
+    }
+    return ret;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
+    GGML_UNUSED(batch_size);
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
     if (vk_instance.debug_utils_support) {
@@ -13142,7 +14107,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     int last_node = cgraph->n_nodes - 1;
 
     // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
-    while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
+    while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
         last_node -= 1;
     }
 
@@ -13152,6 +14117,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch
 
+    ggml_vk_submit_transfer_ctx(ctx);
+
     vk_context compute_ctx;
     if (vk_perf_logger_enabled) {
         // allocate/resize the query pool
@@ -13165,17 +14132,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
             ctx->num_queries = query_create_info.queryCount;
             ctx->query_fusion_names.resize(ctx->num_queries);
+            ctx->query_fusion_node_count.resize(ctx->num_queries);
             ctx->query_nodes.resize(ctx->num_queries);
+            ctx->query_node_idx.resize(ctx->num_queries);
         }
 
         ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
         std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
+        std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
         std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
+        std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
 
         GGML_ASSERT(ctx->compute_ctx.expired());
-        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->compute_ctx = compute_ctx;
-        ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        compute_ctx = ggml_vk_get_compute_ctx(ctx);
         ctx->query_idx = 0;
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
     }
@@ -13185,13 +14154,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     if (ctx->prealloc_size_add_rms_partials) {
         ggml_vk_preallocate_buffers(ctx, nullptr);
-        if (ctx->compute_ctx.expired()) {
-            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-            ctx->compute_ctx = compute_ctx;
-            ggml_vk_ctx_begin(ctx->device, compute_ctx);
-        } else {
-            compute_ctx = ctx->compute_ctx.lock();
-        }
+        compute_ctx = ggml_vk_get_compute_ctx(ctx);
         // initialize partial sums to zero.
         ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
         ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -13218,70 +14181,192 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
+        // the fused result in an elementwise-way. This affects whether the memory for
+        // the src is allowed to overlap the memory for the destination.
+        // The array is sized to handle the largest fusion (asserted later).
+        bool op_srcs_fused_elementwise[12];
+
+        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+        ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
         if (!ctx->device->disable_fusion) {
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
                 fusion_string = "MULTI_ADD";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ADD_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_ADD_ID";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
                        ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
                 ctx->num_additional_fused_ops = 4;
                 fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
+                op_srcs_fused_elementwise[3] = false;
+                op_srcs_fused_elementwise[4] = false;
            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "RMS_NORM_MUL_ROPE";
+                // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "RMS_NORM_MUL";
+                // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
+                // they are overwritten, and one workgroup per row. So close enough.
+                op_srcs_fused_elementwise[0] = true;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                        ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
                // the view of the argsort output at node i+3 writes to memory, so mark it in the write mask
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
+                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 4;
+                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
+                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
+            }
+            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
+                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
+                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
+                    ctx->fused_topk_moe_scale = true;
+                    ctx->num_additional_fused_ops++;
+                    op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
+                }
             }
         }
+        GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
+        // Check whether fusion would overwrite src operands while they're still in use.
+        // If so, disable fusion.
+        if (ctx->num_additional_fused_ops) {
+            // There are up to two output nodes - topk_moe has two.
+            uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
+            ggml_tensor *output_nodes[2] {};
+            output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
+            if (bits) {
+                int output_idx = find_first_set(bits);
+                GGML_ASSERT(bits == (1u << output_idx));
+                output_nodes[1] = cgraph->nodes[i + output_idx];
+            }
+
+            bool need_disable = false;
+
+            // topk_moe often overwrites the source, but for a given row all the src values are
+            // loaded before anything is stored. If there's only one row, this is safe, so treat
+            // this as a special case.
+            bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
+                                          ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
+
+            if (!is_topk_moe_single_row) {
+                for (int j = 0; j < 2; ++j) {
+                    ggml_tensor *dst = output_nodes[j];
+                    if (!dst) {
+                        continue;
+                    }
+                    // Loop over all srcs of all nodes in the fusion. If the src overlaps
+                    // the destination and the src is not an intermediate node that's being
+                    // elided, then disable fusion.
+                    for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
+                        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+                            ggml_tensor *src = cgraph->nodes[i + k]->src[s];
+                            if (!src || src->op == GGML_OP_NONE) {
+                                continue;
+                            }
+                            if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
+                                bool found = false;
+                                for (int n = 0; n < k; ++n) {
+                                    if (cgraph->nodes[i + n] == src) {
+                                        found = true;
+                                        break;
+                                    }
+                                }
+                                if (!found) {
+                                    need_disable = true;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            if (need_disable) {
+                ctx->num_additional_fused_ops = 0;
+                ctx->fused_ops_write_mask = 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+                ctx->fused_topk_moe_scale = false;
+            }
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
@@ -13292,16 +14377,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled && enqueued) {
-            if (ctx->compute_ctx.expired()) {
-                compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-                ctx->compute_ctx = compute_ctx;
-                ggml_vk_ctx_begin(ctx->device, compute_ctx);
+            compute_ctx = ggml_vk_get_compute_ctx(ctx);
+            if (!vk_perf_logger_concurrent) {
+                // track a single node/fusion for the current query
+                ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
+                ctx->query_fusion_names[ctx->query_idx] = fusion_string;
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
             } else {
-                compute_ctx = ctx->compute_ctx.lock();
+                // track a fusion string and number of fused ops for the current node_idx
+                ctx->query_fusion_names[i] = fusion_string;
+                ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
             }
-            ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
-            ctx->query_fusion_names[ctx->query_idx] = fusion_string;
-            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
         }
 
         if (enqueued) {
@@ -13339,16 +14425,37 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_submit(compute_ctx, ctx->device->fence);
         VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
         ctx->device->device.resetFences({ ctx->device->fence });
+        ctx->compute_ctx.reset();
 
         // Get the results and pass them to the logger
         std::vector timestamps(cgraph->n_nodes + 1);
         VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
-        for (int i = 1; i < ctx->query_idx; i++) {
-            auto node = ctx->query_nodes[i];
-            auto name = ctx->query_fusion_names[i];
-            ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
+        if (!vk_perf_logger_concurrent) {
+            // Log each op separately
+            for (int i = 1; i < ctx->query_idx; i++) {
+                auto node = ctx->query_nodes[i];
+                auto name = ctx->query_fusion_names[i];
+                ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
+            }
+        } else {
+            // Log each group of nodes
+            int prev_node_idx = 0;
+            for (int i = 1; i < ctx->query_idx; i++) {
+                auto cur_node_idx = ctx->query_node_idx[i];
+                std::vector nodes;
+                std::vector names;
+                for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
+                    if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
+                        continue;
+                    }
+                    nodes.push_back(cgraph->nodes[node_idx]);
+                    names.push_back(ctx->query_fusion_names[node_idx]);
+                    node_idx += ctx->query_fusion_node_count[node_idx];
+                }
+                prev_node_idx = cur_node_idx;
+                ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
+            }
         }
-
         ctx->perf_logger->print_timings();
     }
 
@@ -13359,7 +14466,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     return GGML_STATUS_SUCCESS;
 
     UNUSED(backend);
-    UNUSED(batch_size);
 }
 
 // Sort the graph for improved parallelism.
@@ -13431,6 +14537,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
         if (keep_pattern(topk_moe_early_softmax_norm)) {
             continue;
         }
+        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
+            continue;
+        }
         if (keep_pattern(topk_moe_early_softmax)) {
             continue;
         }
@@ -13457,6 +14566,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             // Don't pull forward nodes from fusion patterns
             if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
                 match_pattern(topk_moe_early_softmax, j) ||
                 match_pattern(topk_moe_late_softmax, j)) {
                 continue;
@@ -13468,7 +14578,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
-                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
                     ok = false;
                     break;
                 }
@@ -13596,21 +14707,56 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     }
 }
 
+static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+    vk_event *vkev = (vk_event *)event->context;
+
+    ggml_vk_submit_transfer_ctx(ctx);
+
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
+
+    // the backend interface doesn't have an explicit reset, so reset it here
+    // before we record the command to set it
+    ctx->device->device.resetEvent(vkev->event);
+    ctx->device->device.resetFences({ vkev->fence });
+
+    ggml_vk_set_event(compute_ctx, vkev->event);
+
+    ggml_vk_ctx_end(compute_ctx);
+
+    ggml_vk_submit(compute_ctx, {vkev->fence});
+    ctx->submit_pending = true;
+    ctx->compute_ctx.reset();
+}
+
+static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+    vk_event *vkev = (vk_event *)event->context;
+
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
+
+    ggml_vk_wait_events(compute_ctx, {vkev->event});
+    ggml_vk_ctx_end(compute_ctx);
+    ctx->compute_ctx.reset();
+}
+
 // TODO: enable async and synchronize
 static ggml_backend_i ggml_backend_vk_interface = {
     /* .get_name                = */ ggml_backend_vk_name,
     /* .free                    = */ ggml_backend_vk_free,
-    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
+    /* .set_tensor_async        = */ ggml_backend_vk_set_tensor_async,
     /* .get_tensor_async        = */ ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
+    /* .cpy_tensor_async        = */ ggml_backend_vk_cpy_tensor_async,
     /* .synchronize             = */ ggml_backend_vk_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_vk_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
+    /* .event_record            = */ ggml_backend_vk_event_record,
+    /* .event_wait              = */ ggml_backend_vk_event_wait,
     /* .graph_optimize          = */ ggml_vk_graph_optimize,
 };
 
@@ -13653,86 +14799,15 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
     ggml_vk_get_device_description(dev_idx, description, description_size);
 }
 
-std::string ggml_backend_vk_get_device_id(int device) {
+void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
     GGML_ASSERT(device < (int) vk_instance.device_indices.size());
-    int dev_idx = vk_instance.device_indices[device];
-    return ggml_vk_get_device_id(dev_idx);
-}
-
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
-    size_t device;
-    std::string name;
-    std::string description;
-    bool is_integrated_gpu;
-    // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
-    std::string pci_id;
-    std::string id;
-    std::string uuid;
-    std::string luid;
-    int major;
-    int minor;
-    int driver_major;
-    int driver_minor;
-};
-
-void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
-    GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
-    GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
+    GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
 
-    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
+    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
     vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
     vk::PhysicalDeviceMemoryProperties2 memprops = {};
-    const bool membudget_supported = vk_instance.device_supports_membudget[ctx->device];
+    const bool membudget_supported = vk_instance.device_supports_membudget[device];
     const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-    
-    vk::PhysicalDeviceProperties2 props2;
-    vkdev.getProperties2(&props2);
-    GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str());
-    GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: luid %s\n", ctx->luid.c_str());
-
-    // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic)
-    if (ggml_dxgi_pdh_init() == 0) {
-        GGML_LOG_DEBUG("DXGI + PDH Initialized. Getting GPU free memory info\n");
-        int status = ggml_dxgi_pdh_get_device_memory(ctx->luid.c_str(), free, total, ctx->is_integrated_gpu);
-        if (status == 0) {
-            GGML_LOG_DEBUG("%s utilizing DXGI + PDH memory reporting free: %zu total: %zu\n", __func__, *free, *total);
-            ggml_dxgi_pdh_release();
-            return;
-        }
-        ggml_dxgi_pdh_release();
-    }
-
-    if (!is_integrated_gpu)
-    {
-        // Use vendor specific management libraries for best VRAM reporting if available
-        switch (props2.properties.vendorID) {
-        case VK_VENDOR_ID_AMD:
-            if (ggml_hip_mgmt_init() == 0) {
-                int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
-                if (status == 0) {
-                    GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
-                    ggml_hip_mgmt_release();
-                    return;
-                }
-                ggml_hip_mgmt_release();
-            }
-            break;
-        case VK_VENDOR_ID_NVIDIA:
-            if (ggml_nvml_init() == 0) {
-                int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total);
-                if (status == 0) {
-                    GGML_LOG_DEBUG("%s device %s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, ctx->uuid.c_str(), *free, *total);
-                    ggml_nvml_release();
-                    return;
-                }
-                ggml_nvml_release();
-            }
-            break;
-        }
-    }
-    // else fallback to memory budget if supported
 
     if (membudget_supported) {
         memprops.pNext = &budgetprops;
@@ -13775,56 +14850,64 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
 
     const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
 
-    bool ext_support = false;
-
+    vk::PhysicalDeviceProperties devProps = device.getProperties();
+    bool ext_support = devProps.vendorID == VK_VENDOR_ID_AMD || devProps.vendorID == VK_VENDOR_ID_NVIDIA;
     for (const auto& properties : ext_props) {
-        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
+        if (strcmp(properties.extensionName, VK_EXT_PCI_BUS_INFO_EXTENSION_NAME) == 0) {
             ext_support = true;
             break;
         }
     }
 
-    vk::PhysicalDeviceProperties2 props2;
     if (!ext_support) {
-        device.getProperties2(&props2);
-        if (props2.properties.vendorID != VK_VENDOR_ID_AMD) {
-            return "";
-        }
-        // AMD doesn't claim to support PCI ID, but actually does, so try anyway and check for non-zero
+        return "";
     }
 
     vk::PhysicalDeviceProperties2 props = {};
     vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
 
     props.pNext = &pci_bus_info;
+    try {
+        device.getProperties2(&props);
 
-    device.getProperties2(&props);
+        // If not supported and values are 0, it might be invalid
+        if (!ext_support && pci_bus_info.pciDomain == 0 && pci_bus_info.pciBus == 0 &&
+            pci_bus_info.pciDevice == 0 && pci_bus_info.pciFunction == 0) {
+            return "";
+        }
 
-    const uint32_t pci_domain = pci_bus_info.pciDomain;
-    const uint32_t pci_bus = pci_bus_info.pciBus;
-    const uint32_t pci_device = pci_bus_info.pciDevice;
-    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
+        const uint32_t pci_domain = pci_bus_info.pciDomain;
+        const uint32_t pci_bus = pci_bus_info.pciBus;
+        const uint32_t pci_device = pci_bus_info.pciDevice;
+        const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
 
-    char pci_bus_id[16] = {};
-    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
-    if (pci_domain == 0 && pci_bus == 0 && pci_device == 0 && pci_function == 0) {
+        char pci_bus_id[16] = {};
+        snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
+
+        return std::string(pci_bus_id);
+    } catch(...) {
         return "";
     }
-
-    return std::string(pci_bus_id);
 }
 
-static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) {
-    if (id.empty()) return false;
-    unsigned int d = 0, b = 0, dev = 0, func = 0;
-    // Expected format: dddd:bb:dd.f (all hex)
-    int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func);
-    if (n < 4) return false;
-    if (domain) *domain = (int) d;
-    if (bus) *bus = (int) b;
-    if (device) *device = (int) dev;
-    return true;
-}
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+    size_t device;
+    std::string name;
+    std::string description;
+    bool is_integrated_gpu;
+    // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
+    std::string pci_id;
+    std::string id;
+    std::string uuid;
+    std::string luid;
+    int major;
+    int minor;
+    int driver_major;
+    int driver_minor;
+    int op_offload_min_batch_size;
+};
 
 static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -13843,7 +14926,56 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
 
 static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
-    ggml_backend_vk_get_device_memory(ctx, free, total);
+    GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: uuid %s\n", ctx->uuid.c_str());
+    GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: luid %s\n", ctx->luid.c_str());
+
+    // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic)
+    if (ggml_dxgi_pdh_init() == 0) {
+        GGML_LOG_DEBUG("DXGI + PDH Initialized. Getting GPU free memory info\n");
+        int status = ggml_dxgi_pdh_get_device_memory(ctx->luid.c_str(), free, total, ctx->is_integrated_gpu);
+        if (status == 0) {
+            GGML_LOG_DEBUG("%s utilizing DXGI + PDH memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+            ggml_dxgi_pdh_release();
+            return;
+        }
+        ggml_dxgi_pdh_release();
+    }
+
+    // Use vendor specific management libraries for best VRAM reporting if available
+    if (!ctx->is_integrated_gpu) {
+        GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
+        vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
+        vk::PhysicalDeviceProperties2 props2;
+        vkdev.getProperties2(&props2);
+
+        switch (props2.properties.vendorID) {
+        case VK_VENDOR_ID_AMD:
+            if (ggml_hip_mgmt_init() == 0) {
+                int status = ggml_hip_get_device_memory(!ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
+                if (status == 0) {
+                    GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, !ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+                    ggml_hip_mgmt_release();
+                    return;
+                }
+                ggml_hip_mgmt_release();
+            }
+            break;
+        case VK_VENDOR_ID_NVIDIA:
+            if (ggml_nvml_init() == 0) {
+                int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total);
+                if (status == 0) {
+                    GGML_LOG_DEBUG("%s device %s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, ctx->uuid.c_str(), *free, *total);
+                    ggml_nvml_release();
+                    return;
+                }
+                ggml_nvml_release();
+            }
+            break;
+        }
+    }
+
+    // Fallback to Vulkan memory budget
+    ggml_backend_vk_get_device_memory(ctx->device, free, total);
 }
 
 static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -13893,6 +15025,35 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    const vk_device& device = ggml_vk_get_device(ctx->device);
+
+    const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
+                          device->shader_int64 && device->buffer_device_address;
+
+    auto const & tensor_size_supported = [&](size_t tensor_size) {
+        if (tensor_size > device->max_buffer_size) {
+            return false;
+        }
+        // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
+        // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
+        if (!uses_bda && !device->shader_64b_indexing) {
+            if (tensor_size > device->properties.limits.maxStorageBufferRange) {
+                return false;
+            }
+        }
+        return true;
+    };
+    // reject any tensors larger than the max buffer size
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
+            return false;
+        }
+    }
+    if (!tensor_size_supported(ggml_nbytes(op))) {
+        return false;
+    }
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -13902,6 +15063,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_XIELU:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
@@ -13940,8 +15102,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL_MAT_ID:
             {
                 ggml_type src0_type = op->src[0]->type;
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
                 if (op->op == GGML_OP_MUL_MAT_ID) {
                     if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                         // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
@@ -14002,8 +15162,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 bool coopmat2 = device->coopmat2;
                 uint32_t HSK = op->src[1]->ne[0];
                 uint32_t HSV = op->src[2]->ne[0];
@@ -14180,6 +15338,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT_BACK:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -14190,8 +15349,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_L2_NORM:
+            return ggml_is_contiguous_rows(op->src[0]) &&
+                   op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -14225,8 +15386,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // pipeline_argsort_large_f32 requires vulkan memory model.
                 if (device->vulkan_memory_model) {
                     return true;
@@ -14239,8 +15398,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // We could potentially support larger, using argsort to sort the
                 // whole thing. Not clear if this is needed.
                 uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
@@ -14251,9 +15408,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             return true;
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
-        case GGML_OP_ACC:
+            if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
+                if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
+                    return false;
+                }
+            }
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_SET:
+            return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
+                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
         case GGML_OP_CONCAT:
             return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
         case GGML_OP_ADD1:
@@ -14282,8 +15447,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_CUMSUM:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
                     return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
                 }
@@ -14291,9 +15454,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_SOLVE_TRI:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                 if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
@@ -14358,14 +15518,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
                 }
 
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
-                const uint32_t SPLIT_H = 16;
+                size_t shmem_size = d_state * sizeof(float);
 
-                size_t stateC_size = SPLIT_H * d_state * sizeof(float);
+                if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
+                    return false;
+                }
 
-                if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
+                if (!device->subgroup_basic) {
                     return false;
                 }
 
@@ -14404,13 +15563,111 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
     return buft_ctx->device->idx == ctx->device;
 }
 
+static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_GET_ROWS:
+            return 0;
+        case GGML_OP_MUL_MAT:
+            return op->ne[1];
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+            return op->ne[2];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
 static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
+    ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
+
+    return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
+}
 
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
 
-    UNUSED(dev);
+    vk_event *vkev = new vk_event;
+    if (!vkev) {
+        return nullptr;
+    }
+
+    // The event/fence is expected to initially be in the signaled state.
+    vkev->event = device->device.createEvent({});
+    vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
+    device->device.setEvent(vkev->event);
+
+    return new ggml_backend_event {
+        /* .device  = */ dev,
+        /* .context = */ vkev,
+    };
+}
+
+static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+
+    vk_event *vkev = (vk_event *)event->context;
+
+    device->device.destroyFence(vkev->fence);
+    device->device.destroyEvent(vkev->event);
+    delete vkev;
+    delete event;
+}
+
+static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+    vk_event *vkev = (vk_event *)event->context;
+
+    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
+}
+
+static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
+    if (!device->external_memory_host) {
+        return {};
+    }
+
+    uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
+    if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
+        return {};
+    }
+    if (size & (device->min_imported_host_pointer_alignment - 1)) {
+        return {};
+    }
+
+    const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
+
+    vk_buffer buf {};
+    try {
+        buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
+    } catch (vk::SystemError& e) {
+        GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
+    }
+
+    return buf;
+}
+
+static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
+    GGML_UNUSED(max_tensor_size);
+
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    auto device = ggml_vk_get_device(ctx->device);
+
+    vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
+
+    if (!buf) {
+        return {};
+    }
+
+    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
+
+    ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
+
+    return ret;
 }
 
 static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
@@ -14422,13 +15679,13 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
     /* .init_backend         = */ ggml_backend_vk_device_init,
     /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
     /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
     /* .supports_op          = */ ggml_backend_vk_device_supports_op,
     /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
     /* .offload_op           = */ ggml_backend_vk_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
+    /* .event_new            = */ ggml_backend_vk_device_event_new,
+    /* .event_free           = */ ggml_backend_vk_device_event_free,
+    /* .event_synchronize    = */ ggml_backend_vk_device_event_synchronize,
 };
 
 static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
@@ -14451,6 +15708,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
         std::lock_guard<std::mutex> lock(mutex);
         if (!initialized) {
             std::vector<vk::PhysicalDevice> vk_devices = vk_instance.instance.enumeratePhysicalDevices();
+            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
 
             for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
                 ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
@@ -14461,12 +15719,13 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
                 ctx->description = desc;
                 ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
                 ctx->pci_id = ggml_backend_vk_get_device_pci_id(i);
-                ctx->id = ggml_backend_vk_get_device_id(i);
+                ctx->id = ggml_vk_get_device_id(i);
                 devices.push_back(new ggml_backend_device {
                     /* .iface   = */ ggml_backend_vk_device_i,
                     /* .reg     = */ reg,
                     /* .context = */ ctx,
                 });
+
                 // Gather additional information about the device
                 int dev_idx = vk_instance.device_indices[i];
                 vk::PhysicalDeviceProperties props1;
@@ -14482,8 +15741,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
                 std::ostringstream oss;
                 oss << std::hex << std::setfill('0');
                 int byteIdx = 0;
-                for (int i = 0; i < 16; ++i, ++byteIdx) {
-                    oss << std::setw(2) << static_cast<uint32_t>(device_id_props.deviceUUID[i]);
+                for (int j = 0; j < 16; ++j, ++byteIdx) {
+                    oss << std::setw(2) << static_cast<uint32_t>(device_id_props.deviceUUID[j]);
                     if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) {
                         oss << '-';
                     }
@@ -14502,6 +15761,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
                 // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
                 ctx->driver_major = 0;
                 ctx->driver_minor = 0;
+                ctx->op_offload_min_batch_size = min_batch_size;
             }
             initialized = true;
         }
@@ -14620,6 +15880,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
     }
 }
 
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
+    VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
+
+    if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
+        return 0;
+    }
+
+    const uint32_t device_id = props.properties.deviceID;
+
+    switch (device_id) {
+    case 0x56A6:  // A310
+        return 6;
+    case 0x5693:  // A370M
+    case 0x56A5:  // A380
+    case 0x56B1:  // Pro A40/A50
+        return 8;
+    case 0x5697:  // A530M
+        return 12;
+    case 0x5692:  // A550M
+    case 0x56B3:  // Pro A60
+        return 16;
+    case 0x56A2:  // A580
+        return 24;
+    case 0x5691:  // A730M
+    case 0x56A1:  // A750
+        return 28;
+    case 0x56A0:  // A770
+    case 0x5690:  // A770M
+        return 32;
+    case 0xE212:  // Pro B50
+        return 16;
+    case 0xE20C:  // B570
+        return 18;
+    case 0xE20B:  // B580
+        return 20;
+    default:
+        return 0;
+    }
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
@@ -14845,7 +16145,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else if (tensor->op == GGML_OP_LOG) {
             tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_TRI) {
-            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
         } else if (tensor->op == GGML_OP_DIAG) {
             tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_CLAMP) {
@@ -14862,6 +16162,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
         } else if (tensor->op == GGML_OP_ACC) {
             tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+        } else if (tensor->op == GGML_OP_SET) {
+            tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
         } else if (tensor->op == GGML_OP_NORM) {
             tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
         } else if (tensor->op == GGML_OP_GROUP_NORM) {
@@ -14933,6 +16235,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_RELU:
                 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_XIELU:
+                tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
+                ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
+                ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
+                ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
+                ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
+                break;
             case GGML_UNARY_OP_NEG:
                 tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
                 break;
@@ -15287,7 +16596,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
         ggml_vk_print_graph_origin(tensor, done);
     }
 
-    if (avg_err > 0.5 || std::isnan(avg_err)) {
+    if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
         std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
         if (src0 != nullptr) {
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
index 5084a70ed49..6ba3d1d89e0 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
@@ -3,6 +3,9 @@
 #include "types.glsl"
 #include "generic_binary_head.glsl"
 
+// false for SET, true for ACC
+layout(constant_id = 1) const bool ACC = true;
+
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
@@ -13,17 +16,22 @@ void main() {
 
     const uint offset = p.param3;
     const uint src1_i = idx - offset;
-    const uint oz = src1_i / p.nb02;
-    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
-    const uint ox = src1_i % p.nb01;
+    const uint i3 = src1_i / p.nb03;
+    const uint rem2 = src1_i - i3 * p.nb03;
+    const uint i2 = rem2 / p.nb02;
+    const uint rem1 = rem2 - i2 * p.nb02;
+    const uint i1 = rem1 / p.nb01;
+    const uint i0 = rem1 % p.nb01;
 
     uint i00, i01, i02, i03;
-    get_indices(idx, i00, i01, i02, i03);
 
-    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+    if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
+        if (ACC) {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        } else {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        }
     } else {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
     }
 }
-
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
new file mode 100644
index 00000000000..ffc8608691f
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.glsl"
+
+layout (push_constant) uniform parameter
+{
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t nb00;
+    uint32_t nb01;
+    uint32_t a_offset;
+} p;
+
+#define BLOCK_SIZE 256
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {uint data_a[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+shared uint vals[BLOCK_SIZE];
+
+void main() {
+    const uint expert_id = gl_WorkGroupID.x;
+    const uint num_elements = p.ne00 * p.ne01;
+    const uint tid = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) {
+        const uint i01 = idx / p.ne00;
+        const uint i00 = idx % p.ne00;
+        const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00];
+
+        count += uint(a == expert_id);
+    }
+
+    vals[tid] = count;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] += vals[tid + s];
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        data_d[expert_id] = vals[0];
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
index a4c8fc354e9..75e3c3b0eb4 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
@@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 128;
 layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
 
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
@@ -38,32 +39,45 @@ void main() {
         last_sum = 0;
     }
 
-    uint col = tid;
-    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+    uint col = tid * ELEM_PER_THREAD;
+    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
     for (int i = 0; i < num_iter; ++i) {
-        FLOAT_TYPE v = 0;
-        if (col < p.n_cols) {
-            v = FLOAT_TYPE(data_a[src_idx + col]);
+        FLOAT_TYPE v[ELEM_PER_THREAD];
+        FLOAT_TYPE thread_sum = 0;
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
+            }
+            v[j] = thread_sum;
         }
-        v = subgroupInclusiveAdd(v);
 
+        thread_sum = subgroupExclusiveAdd(thread_sum);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += thread_sum;
+        }
         // Store the largest partial sum for each subgroup, then add the partials for all
         // lower subgroups and the final partial sum from the previous iteration.
         if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
-            partial[subgroup_id] = v;
+            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
         }
         barrier();
-        for (int j = 0; j < subgroup_id; ++j) {
-            v += partial[j];
+        for (int s = 0; s < subgroup_id; ++s) {
+            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+                v[j] += partial[s];
+            }
+        }
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += last_sum;
         }
-        v += last_sum;
         barrier();
         if (tid == BLOCK_SIZE - 1) {
-            last_sum = v;
+            last_sum = v[ELEM_PER_THREAD - 1];
         }
-        if (col < p.n_cols) {
-            data_d[dst_idx + col] = D_TYPE(v);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                data_d[dst_idx + col + j] = D_TYPE(v[j]);
+            }
         }
-        col += BLOCK_SIZE;
+        col += BLOCK_SIZE * ELEM_PER_THREAD;
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp
new file mode 100644
index 00000000000..6d39f927fc1
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp
@@ -0,0 +1,60 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint col = gl_GlobalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    uint subgroup_id = tid / SUBGROUP_SIZE;
+
+    FLOAT_TYPE v = 0;
+    if (col < p.n_cols) {
+        v = FLOAT_TYPE(data_a[src_idx + col]);
+    }
+    v = subgroupInclusiveAdd(v);
+
+    // Store the largest partial sum for each subgroup, then add the partials for all
+    // lower subgroups within this workgroup (the cross-block carry is added in pass 2).
+    if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+        partial[subgroup_id] = v;
+    }
+    barrier();
+    for (int j = 0; j < subgroup_id; ++j) {
+        v += partial[j];
+    }
+    barrier();
+    if (tid == BLOCK_SIZE - 1) {
+        data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
+    }
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = D_TYPE(v);
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp
new file mode 100644
index 00000000000..e401893466c
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp
@@ -0,0 +1,66 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    const uint col = gl_GlobalInvocationID.x;
+
+    float v = 0;
+    // prefetch value we're adding to
+    if (col < p.n_cols) {
+        v = data_d[dst_idx + col];
+    }
+
+    // compute the sum of all previous blocks
+    uint c = tid;
+    float sum = 0;
+    while (c < gl_WorkGroupID.x) {
+        sum += data_t[c + gl_NumWorkGroups.x * row];
+        c += BLOCK_SIZE;
+    }
+
+    sum = subgroupAdd(sum);
+    if (gl_SubgroupInvocationID == 0) {
+        temp[gl_SubgroupID] = sum;
+    }
+    barrier();
+    sum = 0;
+    [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
+        sum += temp[s];
+    }
+
+    // Add the sum to what the first pass computed
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = v + sum;
+    }
+}
+
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
index 70ee542d969..7865a6bda79 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
     const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
     const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
     const uint qshift = (iqs & 16) >> 2;
-    u8vec4 qs = u8vec4(
-        data_a[a_offset + ib].qs[iq + 0],
-        data_a[a_offset + ib].qs[iq + 1],
-        data_a[a_offset + ib].qs[iq + 2],
-        data_a[a_offset + ib].qs[iq + 3]
-    );
-    qs = (qs >> qshift) & uint8_t(0xF);
+    const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F);
 
     const float dl = float(int(sl | (sh << 4)) - 32);
     return dl * vec4(
@@ -468,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
 
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 vec2 get_dm(uint ib, uint a_offset) {
-    return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
+    const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
+    return dm;
 }
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 0379e5d5024..ec48f5b1152 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -3,9 +3,13 @@
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
 
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_subgroup_extended_types_float16 : require
+#endif
+
 #extension GL_KHR_shader_subgroup_shuffle : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 
@@ -15,8 +19,10 @@
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;
 
-const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t rows_per_thread = Br / row_split;
+const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
+const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
 
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
+// If SubGroupSize is set to 0 then only use shmem reductions
+const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
+shared float tmpsh[tmpsh_size];
+shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
 
-shared FLOAT_TYPE tmpsh[WorkGroupSize];
-shared vec4 tmpshv4[WorkGroupSize];
+const uint32_t masksh_stride = Br + 1;
+shared FLOAT_TYPE masksh[Bc * masksh_stride];
 
-shared float masksh[Bc][Br];
-shared vec4 Qf[Br][HSK / 4];
+const uint32_t qf_stride = HSK / 4 + 1;
+shared FLOAT_TYPEV4 Qf[Br * qf_stride];
+
+const uint32_t D = HSK > HSV ? HSK : HSV;
+const uint32_t kvsh_stride = D / 4 + 1;
+shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
+
+shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -50,50 +58,70 @@ void main() {
     init_indices();
 
     const uint32_t tid = gl_LocalInvocationIndex;
+    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
+    const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+
+    if (LIMIT_OCCUPANCY_SHMEM > 0) {
+        // This just exists to avoid the occupancy_limiter array getting optimized out
+        occupancy_limiter[tid] = vec4(tid);
+
+        barrier();
+
+        if (occupancy_limiter[tid] == vec4(99999.0)) {
+            data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
+        }
+    }
+
+#define tile_row(r) (row_tid * rows_per_thread + (r))
 
-    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
         uint32_t d = (idx + tid) % (HSK / 4);
         uint32_t r = (idx + tid) / (HSK / 4);
         if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
-            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+            Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();
 
-    vec4 Of[Br][HSV_per_thread / 4];
+    FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] = vec4(0.0);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] = FLOAT_TYPEV4(0.0);
         }
     }
 
-    float Lf[Br], Mf[Br];
+    float Lf[rows_per_thread], Mf[rows_per_thread];
 
     // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
     const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lf[r] = 0;
         Mf[r] = NEG_FLT_MAX_OVER_2;
     }
 
-    float slope[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        slope[r] = 1.0;
+    ACC_TYPE slope[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        slope[r] = ACC_TYPE(1.0);
     }
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
         }
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -101,69 +129,149 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
-    uint32_t m_offset = 0;
+    uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
+        if (MASK_ENABLE) {
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                        masksh[c][r] = m;
-                        max_mask = max(max_mask, m);
-                    } else {
-                        masksh[c][r] = float(0);
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                barrier();
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) % Bc;
+                    uint32_t r = (idx + tid) / Bc;
+                    if (idx + tid < Bc * Br) {
+                        if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                            FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                            masksh[c * masksh_stride + r] = m;
+                            max_mask = max(max_mask, float(m));
+                        } else {
+                            masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
+                        }
                     }
                 }
-            }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
-        float Sf[Br][cols_per_thread];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        ACC_TYPE Sf[rows_per_thread][cols_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Sf[r][c] = 0.0;
+                Sf[r][c] = ACC_TYPE(0.0);
             }
         }
 
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK / 4);
+                uint32_t c = (idx + tid) / (HSK / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
+                    FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+                    }
 
-        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
+                    kvsh[c * kvsh_stride + d] = K_Tf;
+                }
             }
+            barrier();
+        }
+
+        // With many d iterations, caching Q in registers becomes worthwhile.
+        // With few, the extra registers cost more than the caching saves.
+        if (HSK_per_thread / 4 > 4) {
             [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Q_cache[rows_per_thread];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
+                }
+
+                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                        continue;
+                    }
+
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
-                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+                    }
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
+
+                [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+                    }
                 }
             }
         }
@@ -171,89 +279,109 @@ void main() {
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             // Compute sum across the D_split
             [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                 }
             }
         }
 
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        if (LOGIT_SOFTCAP) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+                    Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
                 }
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    float mvf = masksh[c * cols_per_iter + col_tid][r];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
 
                     Sf[r][c] += slope[r]*mvf;
                 }
             }
-            barrier();
         }
 
-        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
+        float eMf[rows_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+                rowmaxf = max(rowmaxf, float(Sf[r][c]));
             }
-            Moldf[r] = Mf[r];
+            float Moldf = Mf[r];
 
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
-            Mf[r] = max(rowmaxf[r], Moldf[r]);
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
-            }
-            eMf[r] = exp(Moldf[r] - Mf[r]);
+            Mf[r] = max(rowmaxf, Moldf);
+            eMf[r] = exp(Moldf - Mf[r]);
+            Lf[r] = eMf[r]*Lf[r];
+        }
 
-            // Compute sum across row of P
-            rowsumf[r] = 0.0;
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                    continue;
-                }
-                rowsumf[r] += Pf[r][c];
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
             }
-
-            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                Of[r][d] = eMf[r] * Of[r][d];
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV / 4);
+                uint32_t c = (idx + tid) / (HSV / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
+                    FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                        V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
             }
+            barrier();
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                 continue;
             }
+
+            FLOAT_TYPE Pf[rows_per_thread];
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
+                Lf[r] += Pf[r];
+            }
+
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Vf;
+                if (SHMEM_STAGING != 0) {
+                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
 #else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                    Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Of[r][d] += Pf[r][c] * Vf;
+                }
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
                 }
             }
         }
-
-        barrier();
     }
 
     // prevent race on tmpsh
@@ -261,58 +389,115 @@ void main() {
 
     // reduce across threads
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        float rowmaxf, eMf;
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        float rowmaxf = Mf[r];
 
-        tmpsh[tid] = Mf[r];
         // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
+            }
+            if (row_split == 1) {
+                // Reduce inside workgroup with shmem
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
+                }
+                barrier();
+                rowmaxf = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
+                }
             }
+        } else {
             barrier();
+            tmpsh[tid] = rowmaxf;
+            barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
+                }
+                barrier();
+            }
+            rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
-        rowmaxf = tmpsh[d_tid];
-        barrier();
 
         float Moldf = Mf[r];
 
         // M = max(rowmax, Mold)
         // eM = e^(Mold - M)
         Mf[r] = max(rowmaxf, Moldf);
-        eMf = exp(Moldf - Mf[r]);
+        float eMf = exp(Moldf - Mf[r]);
 
         Lf[r] = eMf*Lf[r];
 
-        tmpsh[tid] = Lf[r];
-
         // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                Lf[r] += subgroupShuffleXor(Lf[r], s);
+            }
+            if (row_split == 1) {
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
+                }
+                barrier();
+                Lf[r] = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    Lf[r] += tmpsh[s * D_split + d_tid];
+                }
             }
+        } else {
+            barrier();
+            tmpsh[tid] = Lf[r];
             barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
+                }
+                barrier();
+            }
+            Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
-        Lf[r] = tmpsh[d_tid];
-        barrier();
 
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
 
-            Of[r][d] = eMf * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
-
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    Of[r][d] += tmpshv4[tid + s];
-                    tmpshv4[tid] = Of[r][d];
+            if (SubGroupSize > 0) {
+                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                    if (!OLD_AMD_WINDOWS) {
+                        Of[r][d] += subgroupShuffleXor(Of[r][d], s);
+                    } else {
+                        // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
+                        // Shuffle full vec4 as workaround.
+                        // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
+                        Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s));
+                    }
                 }
+                if (row_split == 1) {
+                    barrier();
+                    if (gl_SubgroupInvocationID == d_tid) {
+                        tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
+                    }
+                    barrier();
+                    Of[r][d] = tmpshv4[d_tid];
+                    [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                        Of[r][d] += tmpshv4[s * D_split + d_tid];
+                    }
+                }
+            } else {
+                barrier();
+                tmpshv4[tid] = Of[r][d];
                 barrier();
+                [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                    if (rowgroup_tid < s) {
+                        Of[r][d] += tmpshv4[tid ^ s];
+                        tmpshv4[tid] = Of[r][d];
+                    }
+                    barrier();
+                }
+                Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
             }
-            Of[r][d] = tmpshv4[d_tid];
-            barrier();
         }
     }
 
@@ -320,32 +505,53 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
-
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                     }
                 }
             }
-        }
 
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
             }
-        }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
 
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
+                    }
+                }
+
+                if (global_row < N && d_tid == 0 && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
+            }
+        }
         return;
     }
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
 
             float ms = 1.0f;
             float vs = 1.0f;
@@ -354,7 +560,7 @@ void main() {
                 ms = exp(Mf[r] - sink);
 
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ms;
+                    Of[r][d] *= FLOAT_TYPE(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -364,39 +570,37 @@ void main() {
         }
     }
 
-    float Lfrcp[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    float Lfrcp[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] *= Lfrcp[r];
-#if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
-                    }
+                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                 }
             }
         }
     } else {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (i * Br + r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (i * Br + row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
-                    }
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
                 }
             }
         }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index eb93903c468..172d38f034e 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -1,13 +1,23 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t HSK = 32;
-layout (constant_id = 4) const uint32_t HSV = 32;
-layout (constant_id = 5) const uint32_t Clamp = 0;
-layout (constant_id = 6) const uint32_t D_split = 16;
+layout (constant_id =  0) const uint32_t WorkGroupSize = 128;
+layout (constant_id =  1) const uint32_t Br = 1;
+layout (constant_id =  2) const uint32_t Bc = 32;
+layout (constant_id =  3) const uint32_t HSK = 32;
+layout (constant_id =  4) const uint32_t HSV = 32;
+layout (constant_id =  5) const uint32_t Clamp = 0;
+layout (constant_id =  6) const uint32_t D_split = 16;
+layout (constant_id =  7) const uint32_t row_split = 1;
+layout (constant_id =  8) const uint32_t SubGroupSize = 32;
+layout (constant_id =  9) const uint32_t SHMEM_STAGING = 0;
+layout (constant_id = 10) const uint32_t Flags = 0;
+layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
+
+const bool USE_MASK_OPT    = (Flags & 1) != 0;
+const bool MASK_ENABLE     = (Flags & 2) != 0;
+const bool LOGIT_SOFTCAP   = (Flags & 4) != 0;
+const bool OLD_AMD_WINDOWS = (Flags & 8) != 0;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -57,12 +67,17 @@ layout (push_constant) uniform parameter {
 } p;
 
 #define SINK_ENABLE_BIT (1<<24)
-#define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
 layout (binding = 4) readonly buffer S {float data_s[];};
 
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
+
+layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
 
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
@@ -74,17 +89,21 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 
+#ifndef BLOCK_SIZE
+#define BLOCK_SIZE 1
+#endif
+
 #if defined(DATA_A_F32)
 #undef BLOCK_SIZE
 #define BLOCK_SIZE 4
 #define BLOCK_BYTE_SIZE 16
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     // iqs is currently always zero in the flash attention shaders
     if (binding_idx == BINDING_IDX_K) {
-        return k_packed.k_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
     } else {
-        return v_packed.v_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
     }
 }
 #endif
@@ -92,7 +111,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
 #if defined(DATA_A_Q4_0)
 #define BLOCK_BYTE_SIZE 18
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -100,7 +119,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     } else {
         uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -108,24 +127,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     }
 }
 #endif
 
 #if defined(DATA_A_Q8_0)
 #define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     } else {
         const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     }
 }
 #endif
@@ -165,7 +184,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC
 }
 
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
-         iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+         gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
 
 void init_indices()
@@ -173,12 +192,25 @@ void init_indices()
     N = p.N;
     KV = p.KV;
 
-    i = gl_WorkGroupID.x;
-    split_k_index = 0;
-
     if (p.k_num > 1) {
+        if (p.gqa_ratio > 1) {
+            i = 0;
+            // batch and split_k share gl_WorkGroupID.x
+            gqa_iq1 = gl_WorkGroupID.x / p.k_num;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+        } else {
+            gqa_iq1 = 0;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+            i = gl_WorkGroupID.x / p.k_num;
+        }
+    } else if (p.gqa_ratio > 1) {
         i = 0;
-        split_k_index = gl_WorkGroupID.x;
+        gqa_iq1 = gl_WorkGroupID.x;
+        split_k_index = 0;
+    } else {
+        i = gl_WorkGroupID.x;
+        gqa_iq1 = 0;
+        split_k_index = 0;
     }
 
     Tr = CEIL_DIV(N, Br);
@@ -218,3 +250,15 @@ void init_indices()
     // and breaking the alignment detection.
     m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
 }
+
+// Bias applied to softmax to stay in fp16 range.
+// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
+const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    uint32_t offset = (iq2 + r) * HSV / 4 + c;
+    data_ov4[o_offset + offset] = D_TYPEV4(elems);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index c995ab140ee..526e8da384e 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
@@ -14,12 +15,12 @@
 #include "types.glsl"
 #include "flash_attn_base.glsl"
 
-const uint32_t HSK_per_thread = HSK / D_split;
-const uint32_t HSV_per_thread = HSV / D_split;
+// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
+const uint32_t MatBr = 16;
+const uint32_t MatBc = 16;
 
-const uint32_t row_split = 4;
 const uint32_t rows_per_thread = Br / row_split;
-const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
+const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
 
@@ -31,33 +32,28 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
-
-// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
-const uint32_t MatBr = 16;
-const uint32_t MatBc = 16;
-
-shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
-shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
+shared float tmpsh[row_split];
 
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
+const uint psh_stride = Br / 4 + 2;
+shared f16vec4 Psh[Bc * psh_stride];
+
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
-shared ACC_TYPE sfsh[Bc * sfshstride];
+const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
+shared ACC_TYPEV4 sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
-shared f16vec4 ksh[Bc * kshstride];
+const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
+const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
+const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
+const uint vsh_stride = v_cols;
+shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
 
-shared float slope[Br];
+const uint32_t osh_stride = row_split * MatBr / 4;
+shared f16vec4 pvsh[MatBc * osh_stride];
+
+shared ACC_TYPE slope[Br];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -69,9 +65,9 @@ void main() {
     const uint32_t tid = gl_LocalInvocationIndex;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
-    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+    const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
@@ -82,15 +78,10 @@ void main() {
                 Qf[i + tid] = f16vec4(0);
             }
         }
-        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Bc * kshstride) {
-                ksh[i + tid] = f16vec4(0);
-            }
-        }
         barrier();
     }
 
-    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
         uint32_t d = (idx + tid) % (HSK / 4);
@@ -102,10 +93,10 @@ void main() {
     }
     barrier();
 
-    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] = ACC_TYPEV4(0.0);
+    f16vec4 Of[rows_per_thread][d_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
+            Of[r][d] = f16vec4(0.0);
         }
     }
 
@@ -125,15 +116,17 @@ void main() {
             uint r = tid;
             slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
         }
-        barrier();
     } else {
         if (tid < Br) {
             uint r = tid;
-            slope[r] = 1.0;
+            slope[r] = ACC_TYPE(1.0);
         }
-        barrier();
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -141,65 +134,114 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
-    uint32_t m_offset = 0;
+    uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
+    f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        float mask_cache[Bc * Br / WorkGroupSize];
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                        mask_cache[idx / WorkGroupSize] = m;
-                        max_mask = max(max_mask, m);
-                    }
-                }
-            }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
+        [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
+            mask_cache[idx] = f16vec4(0);
+        }
+
+        if (MASK_ENABLE) {
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
             }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
                 continue;
             }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) / (Br / 4);
+                    uint32_t r = (idx + tid) % (Br / 4);
+                    if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                        if ((!KV_bounds_check || j * Bc + c < KV)) {
+                            f16vec4 m;
+                            if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
+                                max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
+                            } else if (i * Br + r * 4 + 2 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            0.0);
+                                max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+                            } else if (i * Br + r * 4 + 1 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+                            } else if (i * Br + r * 4 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max_mask, float(m[0]));
+                            } else {
+                                m = f16vec4(0.0);
+                            }
+                            mask_cache[idx / WorkGroupSize] = m;
+                        }
+                    }
+                }
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
+            }
         }
 
-        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
-            uint32_t d = (idx + tid) % (HSK / 4);
-            uint32_t c = (idx + tid) / (HSK / 4);
-            if (c < Bc && d < HSK / 4) {
-                f16vec4 K_Tf = f16vec4(0);
-                if (!KV_bounds_check || j * Bc + c < KV) {
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK_pad / 4);
+                uint32_t c = (idx + tid) / (HSK_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
 #if BLOCK_SIZE > 1
-                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
-                    uint ib = coord / BLOCK_SIZE;
-                    uint iqs = (coord % BLOCK_SIZE);
-                    K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
-                    K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
 #endif
-                }
+                    }
 
-                ksh[c * kshstride + d] = K_Tf;
+                    kvsh[c * kvsh_stride + d] = K_Tf;
+                }
             }
+            barrier();
         }
-        barrier();
 
         // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
         // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@@ -208,11 +250,59 @@ void main() {
         coopmat KMat;
         coopmat QMat;
 
-        for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
-            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
+        [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
+            // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
+            // If not, f16 K is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If K is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK) {
+#endif
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t col_vec = (idx + tid) % (MatBr / 4);
+                uint32_t row = (idx + tid) / (MatBr / 4);
+                if (idx + tid < Bc * MatBr / 4) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
+#endif
+                    }
+
+                    kvsh[row * kvsh_stride + col_vec] = K_Tf;
+                }
+            }
+            barrier();
+#if BLOCK_SIZE == 1
+            }
+#endif
+
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK)
+#endif
+            {
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#if BLOCK_SIZE == 1
+            else {
+                const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
+                coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#endif
+            } else {
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+            }
 
-            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
-            coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
             SfMat = coopMatMulAdd(KMat, QMat, SfMat);
         }
@@ -221,27 +311,27 @@ void main() {
         coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
         barrier();
 
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) / Br;
-                uint32_t r = (idx + tid) % Br;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
+        if (LOGIT_SOFTCAP) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
                 }
             }
             barrier();
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float f = mask_cache[idx / WorkGroupSize];
-                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        // Mask nem1 bounds check is handled when loading masks
+                        ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
+                        ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
+                        sfsh[c * sfshstride + r] += slopes * masks;
                     }
                 }
             }
@@ -250,143 +340,237 @@ void main() {
 
         float eMf[rows_per_thread];
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint r_vec  = tile_row(r) / 4;
+            const uint r_comp = tile_row(r) % 4;
+
             float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
+                rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
             }
             float Moldf = Mf[r];
 
+            // Compute max across the row
+            rowmaxf = subgroupMax(rowmaxf);
+
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
             Mf[r] = max(rowmaxf, Moldf);
             eMf[r] = exp(Moldf - Mf[r]);
+
+            Lf[r] = eMf[r]*Lf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+            const uint d_local = d0 / threads_per_rowgroup;
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
+                Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
             }
         }
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Lf[r] = eMf[r]*Lf[r];
-        }
 
+        // Calculate and store Pf in Psh
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
-            }
-            float Pf[rows_per_thread];
-            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
-                Lf[r] += Pf[r];
+            const uint col = c * cols_per_iter + col_tid;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
+                const uint row = tile_row(r);
+                if (KV_bounds_check && j * Bc + col >= KV) {
+                    Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
+                } else {
+                    const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
+                    const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
+                    [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
+                        Lf[r + vec_idx] += Pf[vec_idx];
+                    }
+                    Psh[col * psh_stride + row / 4] = Pf;
+                }
             }
-            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+        }
+
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV_pad / 4);
+                uint32_t c = (idx + tid) / (HSV_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
+                    f16vec4 V_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
 #else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                        V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
                 }
             }
         }
-
         barrier();
-    }
 
-    // prevent race on tmpsh
-    barrier();
+        const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
 
-    // reduce across threads
+        // Each subgroup handles HSV/4 columns
+        [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
+            const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
 
-    float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE M = Mf[r];
-        tmpsh[tid] = M;
-        // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            M = max(M, tmpsh[tid ^ s]);
-            barrier();
-            tmpsh[tid] = M;
-            barrier();
-        }
-        rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
-        barrier();
-    }
+            coopmat PVMat = coopmat(0);
 
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        Moldf[r] = Mf[r];
+            // Preload V tiles for [Bc, 16 * num subgroups]
+            const uint v_rows = Bc;
+            const uint v_total = v_rows * v_cols;
+            const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
 
-        // M = max(rowmax, Mold)
-        // eM = e^(Mold - M)
-        Mf[r] = max(rowmaxf[r], Moldf[r]);
-        eMf[r] = exp(Moldf[r] - Mf[r]);
+            // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
+            // If not, f16 V is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If V is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+            // For f16, only preload if not aligned
+            if (KV_bounds_check) {
+#endif
+            [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
+                const uint idx = i * gl_WorkGroupSize.x + tid;
+                const uint row = idx / v_cols;
+                const uint col = idx % v_cols;
 
-        Lf[r] = eMf[r]*Lf[r];
-    }
+                const uint v_row = j * Bc + row;
+                const uint v_col = hsv_tile * MatBc * row_split + col * 4;
 
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE L = Lf[r];
-        tmpsh[tid] = L;
-        // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            L += tmpsh[tid ^ s];
-            barrier();
-            tmpsh[tid] = L;
+                const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
+                const uint ib = coord / BLOCK_SIZE;
+                const uint iqs = coord % BLOCK_SIZE;
+
+                if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
+#if BLOCK_SIZE > 1
+                    kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                    kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
+#endif
+                } else {
+                    kvsh[row * vsh_stride + col] = f16vec4(0.0f);
+                }
+            }
+
+#if BLOCK_SIZE == 1
+            }
+#endif
+            }
             barrier();
-        }
-        Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
-        barrier();
-    }
 
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            const uint o_offset = gl_SubgroupID * MatBr / 4;
 
-            Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
+            if (hsv_offset < HSV_pad) {
+                [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
+                    coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
 
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-                Of[r][d] += tmpshv4[tid ^ s];
-                barrier();
-                tmpshv4[tid] = Of[r][d];
-                barrier();
+                    if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+                    if (!KV_bounds_check) {
+                        // F16 values can be loaded directly from global memory
+                        const uint v_tile_row = j * Bc + bc_chunk * MatBc;
+                        const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
+                        coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+                    } else
+#endif
+                    {
+                        const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+                    } else {
+                        const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+
+                    PVMat = coopMatMulAdd(KMat, QMat, PVMat);
+                }
+
+                // Store PVMat to pvsh and load into Of
+                coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
             }
-            Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
+
             barrier();
+
+            const uint hsv_per_tile = row_split * MatBc;
+            const uint hsv_base = hsv_tile * hsv_per_tile;
+            const uint d_values_per_tile = hsv_per_tile / 4;
+
+            const uint d_start = hsv_tile * d_values_per_tile;
+            const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+
+                [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
+                    const uint d = d_local * threads_per_rowgroup + col_tid;
+                    const uint hsv_col = 4 * d;
+
+                    if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
+                        const uint local_hsv = (hsv_col - hsv_base) / 4;
+                        Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
+                    }
+                }
+            }
         }
+
+        barrier();
+    }
+
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        Lf[r] = subgroupAdd(Lf[r]);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
 
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        const uint d_local = d0 / threads_per_rowgroup;
+                        gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
                     }
                 }
             }
-        }
 
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
+                    }
+                }
+
+                if (global_row < N && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
             }
         }
 
@@ -403,8 +587,9 @@ void main() {
             if (sink > Mf[r]) {
                 ms = exp(Mf[r] - sink);
 
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ACC_TYPE(ms);
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    Of[r][d_local] *= float16_t(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -419,34 +604,37 @@ void main() {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+        const uint d_local = d0 / threads_per_rowgroup;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] *= ACC_TYPE(Lfrcp[r]);
-#if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+            Of[r][d_local] *= float16_t(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
-                    }
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
                 }
             }
         }
     } else {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (i * Br + tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
-                    }
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
                 }
             }
         }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 9a71996383d..0ea181342ce 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
     return max(elem0, elem1);
 }
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
 #define DECODEFUNC , DEQUANTFUNC
 #else
 #define DECODEFUNC
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
+// Store O values for non-GQA split_k. Rows are tokens, not heads.
+D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c < HSV) {
+        uint32_t o_off = HSV * p.ne1
+            * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Store L/M values for non-GQA split_k.
+ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c == 0) {
+        uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+            + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
+    }
+    return elem;
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -85,7 +107,7 @@ void main() {
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
     tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif
@@ -98,7 +120,7 @@ void main() {
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
         q_stride &= ~7;
-#if !defined(BLOCK_SIZE)
+#if BLOCK_SIZE == 1
         k_stride &= ~7;
         v_stride &= ~7;
 #endif
@@ -111,13 +133,13 @@ void main() {
     coopmat Q;
     coopmat Qf16;
 
-    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
+    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
     coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
 
     Qf16 = coopmat(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat O = coopmat(0);
+    coopmat O = coopmat(0);
 
     coopmat L, M;
 
@@ -138,48 +160,67 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
-    uint32_t m_offset = 0;
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
+    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat mv;
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            if (nem1_bounds_check) {
-                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
-
-                coopmat mvmax;
-
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
-                }
-            } else {
-                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-                // Don't clamp against nem1 when GQA is enabled
-                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
-                coopmat mvmax;
-
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+        coopmat mv = coopmat(0);
+        if (MASK_ENABLE) {
 
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+                if (nem1_bounds_check) {
+                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+
+                    coopmat mvmax;
+
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
+                } else {
+                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                    // Don't clamp against nem1 when GQA is enabled
+                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+                    coopmat mvmax;
+
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
                 }
             }
         }
@@ -192,14 +233,14 @@ void main() {
         coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]]
             for (int k = 0; k < S.length(); ++k) {
                 S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
             S += slopeMat*coopmat(mv);
         }
 
@@ -218,6 +259,8 @@ void main() {
 
         coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
 
+        rowmax += coopmat(FATTN_KQ_MAX_OFFSET);
+
         coopmat Mold = M;
 
         // M = max(rowmax, Mold)
@@ -260,11 +303,8 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        // multiply with fp16 accumulation, then add to O.
-        coopmat PV = coopmat(0);
-        PV = coopMatMulAdd(P_A, V, PV);
-
-        O = eMdiag * O + coopmat(PV);
+        O *= coopmat(eMdiag);
+        O = coopMatMulAdd(P_A, V, O);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
@@ -272,12 +312,19 @@ void main() {
     if (p.k_num > 1) {
         coopmat O_D = coopmat(O);
 
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
-        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
-
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
-        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        } else {
+            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
+            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
+        }
         return;
     }
 
@@ -305,7 +352,7 @@ void main() {
             if (sink > Mr[i]) {
                 ms = exp(Mr[i] - sink);
 
-                O[i] *= ms;
+                O[i] *= float16_t(ms);
             } else {
                 vs = exp(sink - Mr[i]);
             }
@@ -319,15 +366,16 @@ void main() {
         Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
     }
 
-    O = Ldiag*O;
+    coopmat O_D = coopmat(O);
+
+    O_D = coopmat(Ldiag)*O_D;
 
 #if defined(ACC_TYPE_MAX)
-    [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
 #endif
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
 
-    coopmat O_D = coopmat(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
new file mode 100644
index 00000000000..8c92c1adcda
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
@@ -0,0 +1,142 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
+layout (constant_id = 2) const uint Br = 32;
+layout (constant_id = 3) const uint Bc = 32;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];};
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+layout (push_constant) uniform parameter {
+    uint nem0;
+    uint nem1;
+    uint nem2;
+    uint nbm1;
+    uint nbm2;
+    uint nbm3;
+    uint nbd1;
+    uint nbd2;
+    uint nbd3;
+};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
+shared float minsh[NUM_SUBGROUPS];
+shared float maxsh[NUM_SUBGROUPS];
+
+// For each Br x Bc block of the mask (input) buffer, read all values and check
+// if it's all -inf or all zero. Write out a two-bit code indicating which it is
+// (or zero for neither). Each workgroup processes 16 tiles and writes out a
+// 32-bit result mask.
+//
+// TODO: This is a lot of work per workgroup, might make sense to split this into
+// more workgroups in the future.
+void main() {
+    // Each workgroup handles a row
+    const uint tid = gl_LocalInvocationIndex;
+    const uint i0 = gl_WorkGroupID.x;
+    const uint i1 = gl_WorkGroupID.y;
+    const uint i2 = gl_WorkGroupID.z % nem2;
+    const uint i3 = gl_WorkGroupID.z / nem2;
+
+    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+    uint result = 0;
+
+    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
+    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
+        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+                uint j0 = (i + tid) % (Bc / 4);
+                uint j1 = (i + tid) / (Bc / 4);
+
+                j0 *= 4;
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            }
+            min_v = subgroupMin(min_v);
+            max_v = subgroupMax(max_v);
+            if (gl_SubgroupInvocationID == 0) {
+                minsh[gl_SubgroupID] = min_v;
+                maxsh[gl_SubgroupID] = max_v;
+            }
+            barrier();
+            if (tid == 0) {
+                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                    min_v = min(min_v, minsh[i]);
+                    max_v = max(max_v, maxsh[i]);
+                }
+                if (max_v <= -FLT_MAX_OVER_2) {
+                    result |= 1 << (2*block_x);
+                }
+                if (min_v == 0.0f && max_v == 0.0f) {
+                    result |= 2 << (2*block_x);
+                }
+            }
+            barrier();
+        }
+    } else {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
+                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
+                    continue;
+                }
+                uint j0 = (i + tid) % Bc;
+                uint j1 = (i + tid) / Bc;
+
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                if (j0 < nem0 && j1 < nem1) {
+                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                    min_v = min(min_v, f);
+                    max_v = max(max_v, f);
+                }
+            }
+            min_v = subgroupMin(min_v);
+            max_v = subgroupMax(max_v);
+            if (gl_SubgroupInvocationID == 0) {
+                minsh[gl_SubgroupID] = min_v;
+                maxsh[gl_SubgroupID] = max_v;
+            }
+            barrier();
+            if (tid == 0) {
+                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                    min_v = min(min_v, minsh[i]);
+                    max_v = max(max_v, maxsh[i]);
+                }
+                if (max_v <= -FLT_MAX_OVER_2) {
+                    result |= 1 << (2*block_x);
+                }
+                if (min_v == 0.0f && max_v == 0.0f) {
+                    result |= 2 << (2*block_x);
+                }
+            }
+            barrier();
+        }
+    }
+
+    if (tid == 0) {
+        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
index 4eaddd31a8f..68917fc0bb0 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];};
 
 layout (push_constant) uniform parameter {
     uint D;
-    uint N;
+    uint ne1;
+    uint ne2;
     uint ne3;
     uint k_num;
     uint sinks;
@@ -24,15 +25,15 @@ void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
-    const uint iq3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.z % p.ne2;
+    const uint i3 = gl_WorkGroupID.z / p.ne2;
 
     uint D = p.D;
-    uint N = p.N;
     uint k_num = p.k_num;
 
-    uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
-    uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
-    uint lm_stride = N * 2;
+    uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
+    uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
+    uint lm_stride = p.ne1 * 2;
 
     // Compute the max m value for the row
     float m_max = -1.0/0.0;
@@ -99,7 +100,7 @@ void main() {
     if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
-            uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
+            uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
             float m = data_a[m_offset + k * lm_stride];
             O += exp(m - m_max) * data_a[o_offset];
         }
@@ -115,6 +116,6 @@ void main() {
         const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
         O = clamp(O, -FLT_MAX, FLT_MAX);
 
-        data_d[iq3 * D * N + D * n + d] = O;
+        data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl
index 66e46ae6796..3797901f043 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl
@@ -6,4 +6,6 @@ layout (push_constant) uniform parameter
     uint KY;
     float param1;
     float param2;
+    float param3;
+    float param4;
 } p;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
index 1827d647a21..db14f5a3cf3 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
     int s0; int s1;
     int p0; int p1;
     int d0; int d1;
+    uint batch_IC;
 } p;
 
 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 layout (buffer_reference) buffer D_ptr {D_TYPE d;};
 #endif
 
-void main() {
+void im2col(const uint y, const uint z) {
     const uint gidx = gl_GlobalInvocationID.x;
 
-    const uint oh = gl_GlobalInvocationID.y;
-    const uint batch = gl_GlobalInvocationID.z / p.IC;
-    const uint ic = gl_GlobalInvocationID.z % p.IC;
+    const uint oh = y;
+    const uint batch = z / p.IC;
+    const uint ic = z % p.IC;
 
     const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
     const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@@ -101,3 +102,15 @@ void main() {
 #endif
     }
 }
+
+void main() {
+    uint y = gl_GlobalInvocationID.y;
+    while (y < p.OH) {
+        uint z = gl_GlobalInvocationID.z;
+        while (z < p.batch_IC) {
+            im2col(y, z);
+            z += gl_NumWorkGroups.z;
+        }
+        y += gl_NumWorkGroups.y;
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
index 83ef2f87958..7d0a1de0df9 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.glsl"
+#include "generic_unary_head.glsl"
 #include "types.glsl"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
+    const uint i3 = row / (p.ne11 * p.ne12);
+    const uint i3_offset = i3 * p.ne12 * p.ne11;
+    const uint i2 = (row - i3_offset) / p.ne11;
+    const uint i2_offset = i2 * p.ne11;
+    const uint i1 = row - i3_offset - i2_offset;
+
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
         sum[tid] += xi * xi;
     }
 
@@ -35,7 +38,7 @@ void main() {
 
     const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
index b3c96576deb..2271be4021b 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
 
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K;
 
     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
index cfc8b0c7f4b..4aeda68c7f2 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
@@ -29,7 +29,10 @@ layout (push_constant) uniform parameter
 #ifdef MUL_MAT_ID
     uint nei0;
     uint ne11;
+    uint expert_i1;
+    uint nbi1;
 #else
+    uint base_work_group_y;
     uint ne02;
     uint ne12;
     uint broadcast2;
@@ -43,9 +46,9 @@ uint expert_id;
 
 void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.y;
+    const uint expert_i0 = gl_WorkGroupID.y;
 #else
-    const uint batch_idx = gl_GlobalInvocationID.y;
+    const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
 #endif
 
 #ifndef MUL_MAT_ID
@@ -60,24 +63,24 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
         batch_idx_a = i03 * p.ne02 + i02;
     }
 #else
-    expert_id = data_ids[expert_idx];
+    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
 #endif
 
     a_offset =
 #ifdef MUL_MAT_ID
-            expert_id * p.batch_stride_a;
+            expert_id * (p.batch_stride_a / QUANT_K);
 #else
-            batch_idx_a * p.batch_stride_a;
+            batch_idx_a * (p.batch_stride_a / QUANT_K);
 #endif
     b_offset =
 #ifdef MUL_MAT_ID
-            (expert_idx % p.ne11) * p.stride_b;
+            (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
 #else
             batch_idx * p.batch_stride_b;
 #endif
     d_offset =
 #ifdef MUL_MAT_ID
-            expert_idx * p.stride_d;
+            expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
 #else
             batch_idx * p.batch_stride_d;
 #endif
@@ -103,12 +106,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +161,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +206,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
index e5cc7ff8629..3ea24a76cec 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
                                const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     // Compute starting index in matrix B for this superblock
     const uint y_idx = i * QUANT_K + 32 * ib32;
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
 
     // Precompute indices for quantization lookup tables
     const uint qh_base = 2 * ib32;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
index c5f5e9cbb2b..fd953c8fadd 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
@@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
             const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);
 
             // index for data_a
-            uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+            uint ibi = a_offset + first_row * num_blocks_per_row + i;
 
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                 const float d = float(data_a[ibi].d);
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
index e424af12c5a..b4f6d1d6b64 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
@@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
index 0cd906dbbf4..d8dafe5f709 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
@@ -11,36 +11,54 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
-
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
+    // Precompute db multiplication factors
+    float db_vals[NUM_ROWS];
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
-        const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
-        const float db = d * (0.5 + scale) * 0.25;
-
+        const uint scale_raw = data_a[ibi].scales[ib32];
+        const uint scale = (scale_raw >> nibble_shift) & 0xF;
+        // Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25
+        db_vals[n] = d * (0.125f + float(scale) * 0.25f);
+        ibi += num_blocks_per_row;
+    }
+    ibi = a_offset + first_row * num_blocks_per_row + i;
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        // Preload grid and sign data for all l values
+        vec4 grid0_vals[2], grid1_vals[2];
+        uint sign_vals[2], sign7_vals[2];
         [[unroll]] for (uint l = 0; l < 2; ++l) {
             const uint qs = data_a[ibi].qs[2 * itid + l];
-            const uint sign = qs >> 9;
-            const uint sign7 = bitCount(sign);
-            const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
-            const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
-
-            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
-                vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
-
-                FLOAT_TYPE sum =
-                      fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),
-                      fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y),
-                      fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z),
-                      fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w),
-                      fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),
-                      fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),
-                      fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),
-                      fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 &  1) != 0 ? -grid1.w : grid1.w),
-                      FLOAT_TYPE(0.0)))))))));
-                temp[j][n] = fma(db, sum, temp[j][n]);
+            sign_vals[l] = qs >> 9;
+            sign7_vals[l] = bitCount(sign_vals[l]);
+            const uvec2 grid_data = iq2xs_grid[qs & 511];
+            grid0_vals[l] = vec4(unpack8(grid_data.x));
+            grid1_vals[l] = vec4(unpack8(grid_data.y));
+        }
+        // Preload B data for all j columns (reduce repeated index calculations)
+        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (uint l = 0; l < 2; ++l) {
+                const uint sign = sign_vals[l];
+                const uint sign7 = sign7_vals[l];
+                const vec4 grid0 = grid0_vals[l];
+                const vec4 grid1 = grid1_vals[l];
+                // Precompute indices
+                const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l;
+                const vec4 b0 = vec4(data_b_v4[b_idx + 0]);
+                const vec4 b4 = vec4(data_b_v4[b_idx + 1]);
+                sum +=
+                    fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x),
+                    fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y),
+                    fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z),
+                    fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w),
+                    fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x),
+                    fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y),
+                    fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z),
+                    fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 &  1) != 0 ? -grid1.w : grid1.w),
+                    FLOAT_TYPE(0.0)))))))));
             }
+            temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]);
         }
         ibi += num_blocks_per_row;
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
index 71bd72d17e3..f75dcf8331d 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
index a4b9ab1f94f..5cdf2a89d0f 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
@@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     const uint y_idx = i * QUANT_K + 32 * ib32;
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
index 40849c691f2..a88898109ab 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
 
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 14093c0de5a..619de054cb8 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
index 528f224d86b..93e48b79012 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index 49d91ad5910..6af5a81587d 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 0d61b4966ec..3695b47b98d 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
index d7a7f6426ee..3e89d91cbb0 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
index 15f005be3ea..6fe3e2dc043 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
@@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 #define K_PER_ITER 8
 #elif defined(DATA_A_QUANT_K)
 #define K_PER_ITER 16
+#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
+#define K_PER_ITER 32
 #else
 #error unimplemented
 #endif
@@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
         cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
         cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
         cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
+#elif K_PER_ITER == 32
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8    ];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
+        cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
+        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
+        cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
+        cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
+        cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
+        cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
 #else
 #error unimplemented
 #endif
@@ -68,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
 
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K_Q8_1;
+    a_offset *= QUANT_K / QUANT_K_Q8_1;
     b_offset /= QUANT_K_Q8_1;
 
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl
index 2389ea0b1ec..6ddbed309d7 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl
@@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
     return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
 }
 #endif
+
+#if defined(DATA_A_IQ1_S)
+void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
+    const uint ib32 = iqs / 32;
+
+    const uint qh = data_a[ib].qh[ib32];
+
+    const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];
+    const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];
+
+    const uint qs0 = qs16_0 & 0xFF;
+    const uint qs1 = qs16_0 >> 8;
+    const uint qs2 = qs16_1 & 0xFF;
+    const uint qs3 = qs16_1 >> 8;
+
+    const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);
+    const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);
+    const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);
+    const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);
+
+    const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);
+    const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
+    const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
+    const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);
+
+    out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,
+                   (grid0 >> 4) & 0x0F0F0F0F,
+                   (grid1 >> 0) & 0x0F0F0F0F,
+                   (grid1 >> 4) & 0x0F0F0F0F);
+    out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,
+                   (grid2 >> 4) & 0x0F0F0F0F,
+                   (grid3 >> 0) & 0x0F0F0F0F,
+                   (grid3 >> 4) & 0x0F0F0F0F);
+}
+
+vec2 get_dm(uint ib, uint iqs) {
+    const uint ib32 = iqs / 32;
+
+    const uint qh = data_a[ib].qh[ib32];
+    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+
+    const float d = float(data_a[ib].d);
+    const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
+
+    // the -1 cancels out the bias in iq1s_grid_gpu
+    return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
+}
+
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t q_sum = 0;
+
+    const uint ib_k = ib_a / 8;
+    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
+
+    i32vec4 qs_a0;
+    i32vec4 qs_a1;
+    repack8(ib_k, iqs_k, qs_a0, qs_a1);
+
+    const vec2 dm = get_dm(ib_k, iqs_k);
+
+    q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);
+    q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
+    q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
+    q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
+    q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
+    q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
+    q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
+    q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);
+
+    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));
+}
+#endif
+
+#if defined(DATA_A_IQ1_M)
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    const uint ib_k = ib_a / 8;
+    const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
+
+    const uint ib32 = iqs_k / 32;
+    const uint ib64 = ib32 / 2;
+
+    const uint16_t[4] scales = data_a[ib_k].scales;
+    const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
+    const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
+
+    const uint qs32 = data_a_packed32[ib_k].qs[ib32];
+    const uint qh16 = data_a_packed16[ib_k].qh[ib32];
+
+    float sum = 0;
+    const uint sc = data_a[ib_k].scales[ib64];
+    [[unroll]] for (int l = 0; l < 4; ++l) {
+        const uint ib16 = 2 * ib32 + l / 2;
+        const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
+        const uint qh = qh16 >> (4 * l);
+        const uint qs = (qs32 >> (8 * l)) & 0xFF;
+        const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
+
+        const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);
+
+        int32_t q_sum = 0;
+        q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);
+        q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);
+
+        int32_t y_sum = 0;
+        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);
+        y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);
+
+        // the -1 cancels out the bias in iq1s_grid_gpu
+        sum += dl * (q_sum + y_sum * (delta - 1));
+    }
+    sum *= float(cache_b_ds.x);
+
+    return sum;
+}
+#endif
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 5c5251da39b..79344d33005 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -68,6 +68,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
 #endif
 
 layout (push_constant) uniform parameter
@@ -89,6 +90,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -135,14 +138,20 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #include "mul_mm_funcs.glsl"
 
 void main() {
+    const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+    const uint expert_idx = gl_WorkGroupID.z;
+    if (ic * BN >= data_expert_count[expert_idx]) {
+        return;
+    }
+#endif
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.z;
+#ifndef MUL_MAT_ID
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -156,7 +165,6 @@ void main() {
     const uint blocks_m = (p.M + BM - 1) / BM;
     const uint ir = gl_WorkGroupID.x % blocks_m;
     const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;
 
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
@@ -228,13 +236,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    uint pos_a = (
+    uint pos_a =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
 #endif
-        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+        (ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
 #ifdef MUL_MAT_ID
     uint pos_b = 0;
 #else
@@ -360,7 +368,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
 #ifdef COOPMAT
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index 2e04baa44ec..497a18ff8a7 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -92,6 +94,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
 
 shared u16vec4 row_ids[BN];
 
@@ -107,11 +110,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
 {
     const uint row_i = blockCoords[0];
 
-    if (row_i >= _ne1) {
-        return B_TYPE(0.0);
-    }
-
-    const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
+    const u16vec4 row_idx = row_ids[row_i];
     B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
 
     return ret;
@@ -138,6 +137,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
     uint ids[16];
     uint iter = 0;
 
+    uint expert_count = data_expert_count[expert_idx];
+
     for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
         // prefetch up to 16 elements
         if (iter == 0) {
@@ -166,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         uint id = ids[iter++];
         uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
 
-        ballots_sh[gl_SubgroupID] = ballot;
+        if (gl_SubgroupInvocationID == 0) {
+            ballots_sh[gl_SubgroupID] = ballot;
+        }
         barrier();
 
         uint subgroup_base = 0;
@@ -185,7 +188,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         }
         _ne1 += total;
         iter &= 15;
-        if (_ne1 >= (ic + 1) * BN) {
+        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
             break;
         }
     }
@@ -194,16 +197,29 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
 #endif
 
 void main() {
+    const uint tid = gl_LocalInvocationIndex;
+    const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+    const uint expert_idx = gl_WorkGroupID.z;
+    if (ic * BN >= data_expert_count[expert_idx]) {
+        return;
+    }
+    // initialize to row 0 so we don't need to bounds check
+    if (tid < BN) {
+        row_ids[tid] = u16vec4(0);
+    }
+#if !defined(NEEDS_INIT_IQ_SHMEM)
+    barrier();
+#endif
+#endif
+
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint tid = gl_LocalInvocationIndex;
-
-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.z;
+#ifndef MUL_MAT_ID
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -217,7 +233,6 @@ void main() {
     const uint blocks_m = (p.M + BM - 1) / BM;
     const uint ir = gl_WorkGroupID.x % blocks_m;
     const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
     if (bitCount(p.nei0) == 1) {
@@ -239,12 +254,12 @@ void main() {
 #endif
 
 #ifdef MUL_MAT_ID
-    uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
+    uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
     uint pos_b = 0;
 #else
-    uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
+    uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
     uint pos_b = batch_idx * p.batch_stride_b;
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     uint stride_a = p.stride_a / QUANT_K;
@@ -482,7 +497,7 @@ void main() {
                     coopmat mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
 
                     sum = coopMatMulAdd(mat_a, mat_b, sum);
                 } else {
@@ -490,7 +505,7 @@ void main() {
                     coopmat mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
 
                     sum = coopMatMulAdd(mat_a, mat_b, sum);
                 }
@@ -526,7 +541,7 @@ void main() {
                     coopmat mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
 
                     sum = coopMatMulAdd(mat_a, mat_b, sum);
                 } else {
@@ -534,7 +549,7 @@ void main() {
                     coopmat mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
 
                     sum = coopMatMulAdd(mat_a, mat_b, sum);
                 }
@@ -571,7 +586,7 @@ void main() {
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
-                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 #endif
@@ -583,7 +598,7 @@ void main() {
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
-                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
index 58ede04400d..ce7f2d699a2 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
@@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 #endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 4;
             const uint iqs = idx & 0x03;
@@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
 #elif defined(DATA_A_Q4_1)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 4;
             const uint iqs = idx & 0x03;
 
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
-            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
-            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+            const vec2 dm = vec2(data_a_packed32[ib].dm);
+            const uint vui = data_a_packed32[ib].qs[iqs];
+            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
+            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
 
             buf_a[buf_idx     ] = FLOAT_TYPE_VEC2(v0.xy);
             buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
@@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
 #elif defined(DATA_A_Q5_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = idx & 0x07;
@@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
 #elif defined(DATA_A_Q5_1)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint uint_qh = data_a_packed16[ib].qh;
-            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
-            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
-
-            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
-            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
+            const uint ib = idx / 4;
+            const uint iqs = idx & 0x03;
 
-            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xz);
-            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
+            const vec2 dm = vec2(data_a_packed32[ib].dm);
+            const uint uint_qh = data_a_packed32[ib].qh;
+            const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
+            const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
+            const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
+            const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
+
+            const uint vui = data_a_packed32[ib].qs[iqs];
+            const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
+            const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
+
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v0.xz);
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
+            buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
+            buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
 #elif defined(DATA_A_Q8_0)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                         // 2 values per idx
-            const uint iqs = idx % 128;                        // 0..127
+            const uint ib = idx / 64;                          // 4 values per idx
+            const uint iqs = (idx % 64) * 2;                   // 0,2,4..126
 
             const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
             const uint scalesi = iqs / 8;                      // 0..15
             const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
 
-            const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
+            const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
             const uint scales = data_a[ib].scales[scalesi];
             const vec2 dm = vec2(data_a[ib].dm);
 
-            const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
+            const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
 #elif defined(DATA_A_Q3_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -159,20 +163,22 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint is = iqs / 8;                     // 0..15
             const uint halfsplit = ((iqs % 64) / 16);    // 0,1,2,3
             const uint qsshift = halfsplit * 2;          // 0,2,4,6
-            const uint m = 1 << (4 * n + halfsplit);     // 1,2,4,8,16,32,64,128
 
             const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
                                   | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
             const float dl = float(data_a[ib].d) * float(us - 32);
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi    ] & m) != 0) ? 0 : 4)),
-                                             dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
+            const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy);
+            const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy);
+
+            buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x),
+                                             dl * (qs.y - hm.y));
 #elif defined(DATA_A_Q4_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
+            const uint ib = idx / 64;                  // 4 values per idx
+            const uint iqs = (idx % 64) * 2;           // 0,2,4..126
 
             const uint n = iqs / 32;                   // 0,1,2,3
             const uint b = (iqs % 32) / 16;            // 0,1
@@ -198,14 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m),
-                                             fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
+            const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
+
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
 
-            const uint ib = idx / 128;                 // 2 values per idx
-            const uint iqs = idx % 128;                // 0..127
+            const uint ib = idx / 64;                  // 4 values per idx
+            const uint iqs = (idx % 64) * 2;           // 0,2,4..126
 
             const uint n = iqs / 32;                   // 0,1,2,3
             const uint b = (iqs % 32) / 16;            // 0,1
@@ -213,8 +221,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
             const uint qhi = (iqs % 16) * 2;           // 0,2,4..30
 
-            const uint8_t hm = uint8_t(1 << (iqs / 16));
-
             const vec2 loadd = vec2(data_a[ib].dm);
 
             const uint scidx0 = (is < 4) ? is : (is + 4);
@@ -234,8 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
-            buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m),
-                                             fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
+            const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
+            const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
+            const vec4 q = vec4(unpack8(qs | qh));
+
+            buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
+            buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
             const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@@ -394,11 +404,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 
             const float d = float(data_a[ib].d);
             const uint qs = data_a[ib].qs[iqs];
-            const uint signs = pack32(u8vec4(
-                data_a[ib].qs[is+0],
-                data_a[ib].qs[is+1],
-                data_a[ib].qs[is+2],
-                data_a[ib].qs[is+3]
+            const uint signs = pack32(u16vec2(
+                data_a_packed16[ib].qs[is/2],
+                data_a_packed16[ib].qs[is/2+1]
             ));
             const float db = d * 0.5 * (0.5 + (signs >> 28));
             const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
@@ -443,8 +451,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
             const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
             const uint qshift = (idx & 8) >> 1;
-            u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
-            qs = (qs >> qshift) & uint8_t(0xF);
+            u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy;
 
             const float d = float(data_a[ib].d);
             const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
@@ -452,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             buf_a[buf_idx    ] = FLOAT_TYPE_VEC2(v.xy);
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = idx & 0x07;
@@ -466,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
                                                      kvalues_iq4nl[vui >> 12]);
 #elif defined(DATA_A_MXFP4)
             const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
-            const uint buf_idx = col * SHMEM_STRIDE + row;
+            const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
 
             const uint ib = idx / 8;
             const uint iqs = (idx & 0x07) * 2;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
index 1d0e84ac942..26c5c12a49a 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -13,6 +13,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
     uint ids[16];
     uint iter = 0;
 
+    uint expert_count = data_expert_count[expert_idx];
+
     for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
         // prefetch up to 16 elements
         if (iter == 0) {
@@ -41,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         uint id = ids[iter++];
         uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
 
-        ballots_sh[gl_SubgroupID] = ballot;
+        if (gl_SubgroupInvocationID == 0) {
+            ballots_sh[gl_SubgroupID] = ballot;
+        }
         barrier();
 
         uint subgroup_base = 0;
@@ -60,7 +64,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         }
         _ne1 += total;
         iter &= 15;
-        if (_ne1 >= (ic + 1) * BN) {
+        if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
             break;
         }
     }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index dc8b3df47be..aae1c2e8ae9 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -35,6 +35,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
 #endif
 
 layout (push_constant) uniform parameter
@@ -56,6 +57,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -104,14 +107,20 @@ block_b_cache cache_b;
 #include "mul_mmq_funcs.glsl"
 
 void main() {
+    const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+    const uint expert_idx = gl_WorkGroupID.z;
+    if (ic * BN >= data_expert_count[expert_idx]) {
+        return;
+    }
+#endif
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
-#else
-    const uint batch_idx = gl_GlobalInvocationID.z;
+#ifndef MUL_MAT_ID
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -125,7 +134,6 @@ void main() {
     const uint blocks_m = (p.M + BM - 1) / BM;
     const uint ir = gl_WorkGroupID.x % blocks_m;
     const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;
 
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
@@ -183,13 +191,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
 
-    uint pos_a_ib = (
+    uint pos_a_ib =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / BK) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / BK) +
 #endif
-        ir * BM * p.stride_a + start_k) / BK;
+        (ir * BM * p.stride_a + start_k) / BK;
 #ifdef MUL_MAT_ID
     uint pos_b_ib = 0;
 #else
@@ -270,7 +278,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index 7f32dadf17d..9c297d1c60d 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8      ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
                                                      (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
 
-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
     }
 }
 
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
                               (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
         }
 
-        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
     }
 }
 
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const uint is = iqs_k / 4;
         const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
 
-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
     }
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
index 20e45d0253e..7ea29a07e37 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
@@ -15,6 +15,7 @@
 layout (push_constant) uniform parameter
 {
     uint ne;
+    uint num_blocks;
 } p;
 
 #include "types.glsl"
@@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
 shared float shmem[GROUP_SIZE];
 #endif
 
-void quantize() {
-    const uint wgid = gl_WorkGroupID.x;
+void quantize(const uint wgid) {
     const uint tid = INVOCATION_ID;
 
     // Each thread handles a vec4, so 8 threads handle a block
@@ -45,11 +45,7 @@ void quantize() {
     const uint ib = wgid * blocks_per_group + block_in_wg;
     const uint iqs = tid % 8;
 
-#ifndef QBLOCK_X4
-    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
-        return;
-    }
-#else
+#ifdef QBLOCK_X4
     const uint ibx4_outer = ib / 4;
     const uint ibx4_inner = ib % 4;
 
@@ -123,5 +119,9 @@ void quantize() {
 }
 
 void main() {
-    quantize();
+    uint wgid = gl_WorkGroupID.x;
+    while (wgid < p.num_blocks) {
+        quantize(wgid);
+        wgid += gl_NumWorkGroups.x;
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
index 9d6d3665427..55b89f19a7a 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
 #if RMS_NORM_ROPE_FUSION
     barrier();
     rope_params rp = p.rope;
-    uint rope_row = (samp*nchannels + channel)*nrows + row;
     for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
         if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
-            rope_neox(t, rope_row, rp);
+            rope_neox(t, row, channel, samp, rp);
         } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
-            rope_norm(t, rope_row, rp);
+            rope_norm(t, row, channel, samp, rp);
         }
     }
 #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
index 1c8c69422a9..f2028c4c56b 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
-uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
 #if RMS_NORM_ROPE_FUSION
     // Per-row offset in shared memory
     const uint ix = i0;
 #else
-    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
 #endif
     return ix;
 }
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
     sin_theta = sin(theta) * mscale;
 }
 
-void rope_norm(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0;
-    const uint ix = rope_a_coord(i0, i01, i02, p);
+    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0, i1, i2, i3, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_neox(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0/2;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -120,20 +107,20 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
 }
 
 
-void rope_multi(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
-    const uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
+    }
 
     if (i0 >= p.n_dims) {
         rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
@@ -149,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (p.is_imrope != 0) {
         if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 2 && sector < 2 + 3 * p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         //} else {
-        //    theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+        //    theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     } else {
         if (sector < p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= p.sections[0] && sector < sec_w) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     }
 
@@ -184,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_vision(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    const uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     const int sect_dims = p.sections[0] + p.sections[1];
     const int sec_w = p.sections[1] + p.sections[0];
@@ -206,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (sector < p.sections[0]) {
         const uint p0 = sector;
-        theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
     }
     else if (sector >= p.sections[0] && sector < sec_w) {
         const uint p0 = sector - p.sections[0];
-        theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
     }
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index 7c1fb1cd224..1528fbeeaec 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -5,7 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
-    rope_multi(i0, i1, pc);
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
+        return;
+    }
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_multi(i0, i1, i2, i3, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 68f00c180bb..ad0896095db 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -5,7 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
-    rope_neox(i0, i1, pc);
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
+        return;
+    }
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_neox(i0, i1, i2, i3, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 28a939ec6ad..11220817df0 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -5,7 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
-    rope_norm(i0, i1, pc);
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
+        return;
+    }
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_norm(i0, i1, i2, i3, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
index 82f39cee349..ec6ceaca9bd 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
@@ -5,23 +5,29 @@
 
 struct rope_params {
     uint rope_mode;
-    uint ncols;
+    uint nrows;
     uint n_dims;
     float freq_scale;
-    uint p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint has_ff;
-    uint ne02;
-    uint nb01;
-    uint nb02;
     int sections[4];
     uint is_imrope;
     uint is_back;
     uint set_rows_stride;
+
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
 };
 
 #endif // !defined(GGML_ROPE_PARAMS)
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
index ea1e0fdb416..ca71efb2f55 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
@@ -5,7 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
-    rope_vision(i0, i1, pc);
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
+        return;
+    }
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_vision(i0, i1, i2, i3, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
index 8f67be97995..c7416206dbd 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
@@ -1,6 +1,7 @@
 #version 450
 
 #extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
 #if USE_SUBGROUP_ADD
 #extension GL_KHR_shader_subgroup_arithmetic : enable
 #endif
@@ -9,7 +10,8 @@
 
 layout(constant_id = 0) const uint D_STATE = 128;
 layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
-layout(constant_id = 2) const uint SPLIT_H = 16;
+
+const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
@@ -41,22 +43,28 @@ float softplus(float x) {
     }
 }
 
-shared float stateC[SPLIT_H * D_STATE];
+#if !USE_SUBGROUP_ADD
+shared float temp[D_STATE];
+#endif
 
 void main() {
-    const uint tid = gl_LocalInvocationID.x;
-    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
-    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
-    const uint seq_idx = gl_WorkGroupID.y;
+    const uint subgroup = gl_SubgroupID;
+    const uint lane     = gl_SubgroupInvocationID;
+    const uint tid      = gl_SubgroupID * SUBGROUP_SIZE + lane;
+    const uint subgroup_idx = gl_WorkGroupID.x  * c_factor + subgroup;
+
+    const uint head_idx =  subgroup_idx / d_head;
+    const uint head_off = (subgroup_idx % d_head) * 4;
+    const uint seq_idx  = gl_WorkGroupID.y;
 
     const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
     const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
-    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
+    const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
     const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
     const uint A_base_idx = (head_idx * nb31) / 4;
     const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
     const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
-    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
+    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
     const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
 
     const uint stride_x = nb12 / 4;
@@ -65,76 +73,52 @@ void main() {
     const uint stride_C = nb52 / 4;
     const uint stride_y = n_head * d_head;
 
-    float state[SPLIT_H];
-    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-        state[j] = s0[s0_base_idx + j * D_STATE + tid];
-    }
+    float state[c_factor];
 
-    for (uint i = 0; i < n_tok; i++) {
-        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
+    [[unroll]] for (uint j = 0; j < c_factor; j++) {
+        state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
+    }
 
-        const float dA = exp(dt_soft_plus * A[A_base_idx]);
+    float a = A[A_base_idx];
 
-        const float B_val = B[B_base_idx + i * stride_B + tid];
-        const float C_val = C[C_base_idx + i * stride_C + tid];
+    for (uint i = 0; i < n_tok; i++) {
+        float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
 
-        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
+        float state_sum = 0.0f;
 
+        const float dA   = exp(dt_soft_plus * a);
+        const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
+        [[unroll]] for (uint j = 0; j < c_factor; j++) {
+            float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
+            float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
             state[j] = (state[j] * dA) + (B_val * x_dt);
-
-            stateC[j * D_STATE + tid] = state[j] * C_val;
+            state_sum += state[j] * C_val;
         }
 
+#if USE_SUBGROUP_ADD
+        state_sum = subgroupAdd(state_sum);
+#else
+        temp[tid] = state_sum;
         barrier();
-        [[unroll]]
-        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
-            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
-                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
-                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
-                    stateC[k] += stateC[k + w];
-                }
+        [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
+            if (lane < s) {
+                temp[tid] += temp[tid + s];
             }
             barrier();
         }
-
-        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
-            const uint idx = (tid % SUBGROUP_SIZE) +
-                            D_STATE * (tid / SUBGROUP_SIZE) +
-                            j * D_STATE * (D_STATE / SUBGROUP_SIZE);
-            const uint max_idx = SUBGROUP_SIZE - 1 +
-                            D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
-                            j * D_STATE * (D_STATE / SUBGROUP_SIZE);
-
-            if (idx < SPLIT_H * D_STATE ||
-                max_idx < SPLIT_H * D_STATE) {
-                float sc;
-#if USE_SUBGROUP_ADD
-                sc = stateC[idx];
-                sc = subgroupAdd(sc);
-#else
-                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
-                    if (idx + offset < SPLIT_H * D_STATE) {
-                        stateC[idx] += stateC[idx + offset];
-                    }
-                    barrier();
-                }
-                if (tid % SUBGROUP_SIZE == 0) {
-                    sc = stateC[idx];
-                }
+        // get the value from lane 0
+        state_sum = temp[subgroup * SUBGROUP_SIZE];
+        barrier();
 #endif
 
-                if (tid % SUBGROUP_SIZE == 0) {
-                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
-                    d[y_base_idx + i * stride_y + k] = sc;
-                }
-            }
+        if (lane == 0) {
+            d[y_base_idx + i * stride_y] = state_sum;
         }
-
-        barrier();
     }
 
-    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
-        d[s_base_idx + j * D_STATE + tid] = state[j];
+    // write back the state
+    [[unroll]]
+    for (int j = 0; j < c_factor; j++) {
+        d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
index b83a2b9d2d4..ef2f202ec9b 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
@@ -7,6 +7,10 @@
 
 #include "types.glsl"
 
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
 layout (push_constant) uniform parameter
 {
     uint n_rows;
@@ -14,15 +18,18 @@ layout (push_constant) uniform parameter
     uint n_expert_used;
     float clamp_min;
     float clamp_max;
+    uint gating_func;
+    uint has_bias;
+    uint with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 
 layout(constant_id = 0) const uint WARP_SIZE = 32;
 layout(constant_id = 1) const uint n_experts_spec = 512;
-layout(constant_id = 2) const bool with_norm = true;
-layout(constant_id = 3) const bool late_softmax = false;
-layout(constant_id = 4) const bool nexperts_use_push = false;
+layout(constant_id = 2) const bool nexperts_use_push = false;
 
 uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
 
@@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
 const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
 
 layout (binding = 0, std430) readonly buffer Logits {float logits[];};
-layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
-layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
+layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
+layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
 
 const float INFINITY = 1.0 / 0.0;
 
@@ -87,20 +95,45 @@ void main() {
     }
 
     const uint logits_offset = n_experts * row;
+    const uint bias_offset = 0; // 1D
     const uint weights_offset = n_expert_used * row;
     const uint ids_offset = n_experts * row;
     const uint lane = gl_SubgroupInvocationID;
 
-    float wt[experts_per_thread];
+    float probs[experts_per_thread];
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        probs[i] = -INFINITY;
+    }
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
         const uint expert = i + lane;
-        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
     }
 
-    if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
+    if (gating_func == GATING_FUNC_SOFTMAX) {
+        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
+    } else if (gating_func == GATING_FUNC_SIGMOID) {
+        [[unroll]]
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
+        }
+    }
+
+    float selection_probs[experts_per_thread];
+    if (has_bias != 0) {
+        [[unroll]]
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
+        }
+    } else {
+        [[unroll]]
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_probs[i] = probs[i];
+        }
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -117,14 +150,16 @@ void main() {
     }
 
     for (int k = 0; k < n_expert_used; k++) {
-        float max_val    = wt[0];
+        float max_val    = probs[0];
+        float max_val_s  = selection_probs[0];
         uint   max_expert = lane;
 
         [[unroll]]
-        for (int i = 1; i < experts_per_thread; i++) {
-            const uint expert = lane + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
+        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
+                max_val    = probs[i / WARP_SIZE];
+                max_val_s  = selection_probs[i / WARP_SIZE];
                 max_expert = expert;
             }
         }
@@ -132,9 +167,11 @@ void main() {
         [[unroll]]
         for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
             const float val    = subgroupShuffleXor(max_val, mask);
+            const float val_s  = subgroupShuffleXor(max_val_s, mask);
             const uint  expert = subgroupShuffleXor(max_expert, mask);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
+            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                 max_val    = val;
+                max_val_s  = val_s;
                 max_expert = expert;
             }
         }
@@ -144,16 +181,14 @@ void main() {
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == lane) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
+            selection_probs[max_expert / WARP_SIZE] = -INFINITY;
 
             ids[ids_offset + k] = max_expert;
-            if (with_norm) {
-                wt_sum += max_val;
-            }
+            wt_sum += max_val;
         }
     }
 
-    if (with_norm) {
+    if (with_norm != 0) {
         wt_sum              = subgroupAdd(wt_sum);
         wt_sum              = clamp(wt_sum, clamp_min, clamp_max);
         const float inv_sum = 1.0f / wt_sum;
@@ -164,7 +199,7 @@ void main() {
         }
     }
 
-    if (late_softmax) {
+    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
         softmax_warp_inplace(output_weights, n_expert_used, lane, true);
     }
 
@@ -172,7 +207,7 @@ void main() {
     for (uint i = 0; i < experts_per_thread; ++i) {
         uint idx = i * WARP_SIZE + lane;
         if (idx < n_expert_used) {
-            weights[weights_offset + idx] = output_weights[i];
+            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
         }
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
index 02578c77c4f..bdb2c09259b 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
@@ -172,16 +172,12 @@ struct block_q8_0
     float16_t d;
     int8_t qs[32];
 };
+
 struct block_q8_0_packed16
 {
     float16_t d;
     int16_t qs[32/2];
 };
-struct block_q8_0_packed32
-{
-    float16_t d;
-    int32_t qs[32/4];
-};
 
 #if defined(DATA_A_Q8_0)
 #define QUANT_K QUANT_K_Q8_0
@@ -189,7 +185,6 @@ struct block_q8_0_packed32
 #define QUANT_AUXF 1
 #define A_TYPE block_q8_0
 #define A_TYPE_PACKED16 block_q8_0_packed16
-#define A_TYPE_PACKED32 block_q8_0_packed32
 #define DATA_A_QUANT_LEGACY
 #endif
 
@@ -201,11 +196,13 @@ struct block_q8_1
     f16vec2 ds;
     int8_t qs[32];
 };
+
 struct block_q8_1_packed16
 {
     f16vec2 ds;
     int16_t qs[16];
 };
+
 struct block_q8_1_packed32
 {
     f16vec2 ds;
@@ -218,6 +215,7 @@ struct block_q8_1_x4
     f16vec2 ds[4];
     int32_t qs[32];
 };
+
 struct block_q8_1_x4_packed128
 {
     f16vec2 ds[4];
@@ -398,6 +396,12 @@ struct block_iq1_s {
     uint16_t qh[QUANT_K_IQ1_S/32];
 };
 
+struct block_iq1_s_packed16 {
+    float16_t d;
+    uint16_t qs[QUANT_K_IQ1_S/8/2];
+    uint16_t qh[QUANT_K_IQ1_S/32];
+};
+
 #define QUANT_K_IQ1_M 256
 #define QUANT_R_IQ1_M 1
 
@@ -407,6 +411,18 @@ struct block_iq1_m {
     uint16_t scales[QUANT_K_IQ1_M/64];
 };
 
+struct block_iq1_m_packed16 {
+    uint16_t qs[QUANT_K_IQ1_M/8/2];
+    uint16_t qh[QUANT_K_IQ1_M/16/2];
+    uint16_t scales[QUANT_K_IQ1_M/64];
+};
+
+struct block_iq1_m_packed32 {
+    uint32_t qs[QUANT_K_IQ1_M/8/4];
+    uint32_t qh[QUANT_K_IQ1_M/16/4];
+    uint32_t scales[QUANT_K_IQ1_M/64/2];
+};
+
 struct block_iq1_m_packed64 {
     uint64_t  qs[QUANT_K_IQ1_M/8/8];
     uint64_t  qh[QUANT_K_IQ1_M/16/8];
@@ -417,12 +433,15 @@ struct block_iq1_m_packed64 {
 #define QUANT_K QUANT_K_IQ1_S
 #define QUANT_R QUANT_R_IQ1_S
 #define A_TYPE block_iq1_s
+#define A_TYPE_PACKED16 block_iq1_s_packed16
 #endif
 
 #if defined(DATA_A_IQ1_M)
 #define QUANT_K QUANT_K_IQ1_M
 #define QUANT_R QUANT_R_IQ1_M
 #define A_TYPE block_iq1_m
+#define A_TYPE_PACKED16 block_iq1_m_packed16
+#define A_TYPE_PACKED32 block_iq1_m_packed32
 #endif
 
 #if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
@@ -561,7 +580,270 @@ const uint[1024] iq1s_grid_const = {
     0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
 };
 
+// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
+// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
+// and 0xF0F0F0F0).
+const uint32_t[2048] iq1s_grid_gpu_const = {
+    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+};
+
 shared uint16_t iq1s_grid[2048];
+shared uint32_t iq1s_grid_gpu[2048];
 
 #define NEEDS_INIT_IQ_SHMEM
 void init_iq_shmem(uvec3 wgsize)
@@ -575,6 +857,12 @@ void init_iq_shmem(uvec3 wgsize)
             iq1s_grid[2*idx+1] = g.y;
         }
     }
+    [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
+        uint idx = i + gl_LocalInvocationIndex.x;
+        if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
+            iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
+        }
+    }
     barrier();
 }
 #endif
@@ -1346,10 +1634,28 @@ struct block_iq4_xs
     uint8_t qs[QUANT_K_IQ4_XS/2];
 };
 
+struct block_iq4_xs_packed16
+{
+    float16_t d;
+    uint16_t scales_h;
+    uint16_t scales_l[QUANT_K_IQ4_XS/128];
+    uint16_t qs[QUANT_K_IQ4_XS/4];
+};
+
+struct block_iq4_xs_packed32
+{
+    float16_t d;
+    uint16_t scales_h;
+    uint32_t scales_l;
+    uint32_t qs[QUANT_K_IQ4_XS/8];
+};
+
 #if defined(DATA_A_IQ4_XS)
 #define QUANT_K QUANT_K_IQ4_XS
 #define QUANT_R QUANT_R_IQ4_XS
 #define A_TYPE block_iq4_xs
+#define A_TYPE_PACKED16 block_iq4_xs_packed16
+#define A_TYPE_PACKED32 block_iq4_xs_packed32
 #endif
 
 #define QUANT_K_IQ4_NL 32
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
index 037ab0c78f0..f7d12a8dda6 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
@@ -21,6 +21,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 #define NEAREST  0
 #define BILINEAR 1
 #define BICUBIC  2
+#define BILINEAR_ANTIALIAS 513
 
 layout (constant_id = 0) const uint scale_mode = 0;
 
@@ -62,6 +63,56 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
     return fetch_bilinear(c0, c1, d, i12, i13);
 }
 
+float triangle_filter(float x) {
+    return max(1.0f - abs(x), 0.0f);
+}
+
+float interpolate_bilinear_antialias(uint i10, uint i11, uint i12, uint i13) {
+    const float support1  = max(1.0f, 1.0f / p.sf1);
+    const float invscale1 = 1.0f / support1;
+    const float support0  = max(1.0f, 1.0f / p.sf0);
+    const float invscale0 = 1.0f / support0;
+
+    const uint i02 = uint(i12 / p.sf2);
+    const uint i03 = uint(i13 / p.sf3);
+
+    const float y = (float(i11) + p.pixel_offset) / p.sf1;
+    const float x = (float(i10) + p.pixel_offset) / p.sf0;
+
+    // the range of source pixels that contribute
+    const int x_min = max(int(x - support0 + p.pixel_offset), 0);
+    const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00));
+    const int y_min = max(int(y - support1 + p.pixel_offset), 0);
+    const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01));
+
+    // bilinear filter with antialiasing
+    float val = 0.0f;
+    float total_weight = 0.0f;
+
+    for (int sy = y_min; sy < y_max; sy++) {
+        const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1);
+
+        for (int sx = x_min; sx < x_max; sx++) {
+            const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0);
+            const float weight = weight_x * weight_y;
+
+            if (weight <= 0.0f) {
+                continue;
+            }
+
+            const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00];
+            val += pixel * weight;
+            total_weight += weight;
+        }
+    }
+
+    if (total_weight > 0.0f) {
+        val /= total_weight;
+    }
+
+    return val;
+}
+
 // Bicubic interpolation with alpha = -0.75
 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
 const vec4 bcoeffs1 = vec4( 1.25, -2.25,  0.0, 1.0);
@@ -118,6 +169,9 @@ void main() {
         case BICUBIC:
             result = interpolate_bicubic(i10, i11, i12, i13);
             break;
+        case BILINEAR_ANTIALIAS:
+            result = interpolate_bilinear_antialias(i10, i11, i12, i13);
+            break;
     }
 
     data_d[p.d_offset + idx] = D_TYPE(result);
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index b0ade078c7b..85455988c57 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -330,7 +330,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
         std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
     #endif
 
-    // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
+    // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
     // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
     // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
     if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
@@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
-        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
+        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
+        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {
@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 }
 
 void process_shaders() {
-    std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
-
     // matmul
     for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
         // No coopmats
@@ -622,49 +620,63 @@ void process_shaders() {
         }
     }
 
-    // flash attention
-    for (const auto& f16acc : {false, true}) {
-        std::map fa_base_dict = base_dict;
-        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
-        if (f16acc) {
-            fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
+    for (const bool& fp16 : {false, true}) {
+        std::map base_dict;
+        if (fp16) {
+            base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
+        } else {
+            base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
         }
 
-        for (const auto& tname : type_names) {
-            if (tname == "bf16") continue;
+        // flash attention
+        for (const bool& f16acc : {false, true}) {
+            std::map fa_base_dict = base_dict;
+            fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
+            fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
+            if (fp16 && f16acc) {
+                fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
+            }
+
+            for (const auto& tname : type_names) {
+                if (tname == "bf16") continue;
 
+                if (fp16) {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
-            } else {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
-            }
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
+                } else {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
+                }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            }
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                }
 #endif
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                }
+
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
+                }
             }
         }
     }
 
+    std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
+
     for (const auto& tname : type_names) {
         // mul mat vec
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@@ -685,7 +697,7 @@ void process_shaders() {
 
         // mul mat vec with integer dot product
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
+        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -790,6 +802,8 @@ void process_shaders() {
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
 
+    string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {});
+
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
@@ -853,6 +867,8 @@ void process_shaders() {
     string_to_spv("hardswish_f32",  "hardswish.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("xielu_f16",      "xielu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("xielu_f32",      "xielu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
@@ -925,6 +941,8 @@ void process_shaders() {
     string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
@@ -940,6 +958,10 @@ void process_shaders() {
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
     string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+
+    string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
 
     for (std::string dim_str : {"", "_3d"}) {
         for (bool bda : {false, true}) {
@@ -1117,7 +1139,7 @@ void write_output_files() {
 
     for (const std::string& btype : btypes) {
     for (const auto& tname : type_names) {
-        if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
+        if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
             continue;
         }
         hdr << "extern const void * arr_dmmv_"   << tname << "_" << btype << "_f32_data[3];\n";
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp
new file mode 100644
index 00000000000..35d463bfe44
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    float x = float(data_a[i]);
+
+    float alpha_n = p.param1;
+    float alpha_p = p.param2;
+    float beta = p.param3;
+    float eps = p.param4;
+
+    if (x > 0.0f) {
+        x = alpha_p * x * x + beta * x;
+    } else {
+        const float min_x_eps = min(x, eps);
+        x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;
+    }
+
+    data_d[i] = D_TYPE(x);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c
index c9242a15a00..d3c039bd9de 100644
--- a/ml/backend/ggml/ggml/src/ggml.c
+++ b/ml/backend/ggml/ggml/src/ggml.c
@@ -53,13 +53,15 @@
 
 #define UNUSED GGML_UNUSED
 
+// Needed for ggml_fp32_to_bf16_row()
+#if defined(__AVX512BF16__)
 #if defined(_MSC_VER)
-#define m512bh(p) p
 #define m512i(p) p
 #else
-#define m512bh(p) (__m512bh)(p)
+#include 
 #define m512i(p) (__m512i)(p)
-#endif
+#endif // defined(_MSC_VER)
+#endif // defined(__AVX512BF16__)
 
 #if defined(__linux__) || \
     defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
@@ -902,7 +904,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
 };
 
 const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
-    GGML_ASSERT(type < GGML_TYPE_COUNT);
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return &type_traits[type];
 }
 
@@ -1268,27 +1271,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
 }
 
 int64_t ggml_blck_size(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].blck_size;
 }
 
 size_t ggml_type_size(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].type_size;
 }
 
 size_t ggml_row_size(enum ggml_type type, int64_t ne) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     assert(ne % ggml_blck_size(type) == 0);
     return ggml_type_size(type)*ne/ggml_blck_size(type);
 }
 
-double ggml_type_sizef(enum ggml_type type) {
-    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
-}
-
 const char * ggml_type_name(enum ggml_type type) {
-    return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
+    return type_traits[type].type_name;
 }
 
 bool ggml_is_quantized(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].is_quantized;
 }
 
@@ -1499,6 +1508,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
         (t0->nb[3] == t1->nb[3]);
 }
 
+bool ggml_is_view(const struct ggml_tensor * t) {
+    return ggml_impl_is_view(t);
+}
+
 // check if t1 can be represented as a repetition of t0
 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -1628,11 +1641,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
     const size_t cur_end  = cur_offs + cur_size;
 
     // align to GGML_MEM_ALIGN
+    GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1));
     size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
 
     char * const mem_buffer = ctx->mem_buffer;
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
+    // integer overflow checks
+    if (cur_end > SIZE_MAX - size_needed) {
+        GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed);
+        return NULL;
+    }
+    if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) {
+        GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__,
+                cur_end, size_needed, (size_t) GGML_OBJECT_SIZE);
+        return NULL;
+    }
+
     if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
         GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
                 __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
@@ -1701,6 +1726,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         obj_alloc_size = data_size;
     }
 
+    GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size);
+
     struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
     GGML_ASSERT(obj_new);
 
@@ -3444,7 +3471,8 @@ struct ggml_tensor * ggml_cast(
 
     result->op     = GGML_OP_CPY;
     result->src[0] = a;
-    result->src[1] = result;
+    result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
+                             //       backends for consistency with ggml_cpy_impl() above
 
     return result;
 }
@@ -4841,6 +4869,8 @@ struct ggml_tensor * ggml_pool_1d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
@@ -4871,6 +4901,9 @@ struct ggml_tensor * ggml_pool_2d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+    GGML_ASSERT(ne[1] > 0);
+
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
@@ -5746,7 +5779,7 @@ static struct ggml_tensor * ggml_unary_impl(
         struct ggml_tensor  * a,
         enum ggml_unary_op    op,
         bool                  inplace) {
-    GGML_ASSERT(ggml_is_contiguous_1(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -6559,7 +6592,7 @@ static void ggml_compute_backward(
         case GGML_OP_DIAG_MASK_INF: {
             if (src0_needs_grads) {
                 /* ggml_diag_mask_inf_impl() shouldn't be here */
-                /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+                /* ref:  https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */
                 const int n_past = ((const int32_t *) tensor->op_params)[0];
                 ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
             }
@@ -6723,20 +6756,35 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
-    // check if already visited
-    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
+    if (node->op != GGML_OP_NONE && compute) {
+        node->flags |= GGML_TENSOR_FLAG_COMPUTE;
+    }
+
+    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
     GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
-        // This is the first time we see this node in the current graph.
-        cgraph->visited_hash_set.keys[node_hash_pos] = node;
-        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
-        cgraph->use_counts[node_hash_pos] = 0;
-    } else {
+
+    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
         // already visited
+
+        if (compute) {
+            // update the compute flag regardless
+            for (int i = 0; i < GGML_MAX_SRC; ++i) {
+                struct ggml_tensor * src = node->src[i];
+                if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
+                    ggml_visit_parents_graph(cgraph, src, true);
+                }
+            }
+        }
+
         return node_hash_pos;
     }
 
+    // This is the first time we see this node in the current graph.
+    cgraph->visited_hash_set.keys[node_hash_pos] = node;
+    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+    cgraph->use_counts[node_hash_pos] = 0;
+
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6745,7 +6793,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
 
         struct ggml_tensor * src = node->src[k];
         if (src) {
-            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
 
             // Update the use count for this operand.
             cgraph->use_counts[src_hash_pos]++;
@@ -6776,17 +6824,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
     return node_hash_pos;
 }
 
-static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
     if (!expand) {
         // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
         ggml_graph_clear(cgraph);
     }
 
-    const int n0 = cgraph->n_nodes;
+    const int n_old = cgraph->n_nodes;
 
-    ggml_visit_parents(cgraph, tensor);
+    ggml_visit_parents_graph(cgraph, tensor, compute);
 
-    const int n_new = cgraph->n_nodes - n0;
+    const int n_new = cgraph->n_nodes - n_old;
     GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
 
     if (n_new > 0) {
@@ -6795,8 +6843,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
     }
 }
 
+struct ggml_tensor * ggml_build_forward_select(
+        struct ggml_cgraph  * cgraph,
+        struct ggml_tensor ** tensors,
+        int                   n_tensors,
+        int                   idx) {
+    GGML_ASSERT(idx >= 0 && idx < n_tensors);
+
+    for (int i = 0; i < n_tensors; i++) {
+        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
+    }
+
+    return tensors[idx];
+}
+
 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
-    ggml_build_forward_impl(cgraph, tensor, true);
+    ggml_build_forward_impl(cgraph, tensor, true, true);
 }
 
 void ggml_build_backward_expand(
@@ -7227,6 +7289,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
             return false;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
+
         if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
             continue;
         }
@@ -7308,7 +7374,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
             label);
 }
 
-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
     char color[16];
 
     FILE * fp = ggml_fopen(filename, "w");
@@ -7329,7 +7395,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
         } else if (grad) {
-            if (ggml_graph_find(gf, node)) {
+            if (ggml_graph_find(cgraph, node)) {
                 snprintf(color, sizeof(color), "green");
             } else {
                 snprintf(color, sizeof(color), "lightblue");
@@ -7481,8 +7547,11 @@ void ggml_quantize_free(void) {
 
     iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
     iq2xs_free_impl(GGML_TYPE_IQ2_XS);
+    iq2xs_free_impl(GGML_TYPE_IQ2_S);
     iq2xs_free_impl(GGML_TYPE_IQ1_S);
+    iq2xs_free_impl(GGML_TYPE_IQ1_M);
     iq3xs_free_impl(256);
+    iq3xs_free_impl(512);
 
     ggml_critical_section_end();
 }
diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp
index f91d4fabad3..6fb6ea927f9 100644
--- a/ml/backend/ggml/ggml/src/gguf.cpp
+++ b/ml/backend/ggml/ggml/src/gguf.cpp
@@ -15,6 +15,17 @@
 #include 
 #include 
 
+#define GGUF_MAX_STRING_LENGTH  (1024*1024*1024)
+#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024)
+
+#ifdef _WIN32
+#    define gguf_ftell _ftelli64
+#    define gguf_fseek _fseeki64
+#else
+#    define gguf_ftell ftello
+#    define gguf_fseek fseeko
+#endif
+
 template 
 struct type_to_gguf_type;
 
@@ -217,17 +228,64 @@ struct gguf_context {
 };
 
 struct gguf_reader {
-    FILE * file;
+    gguf_reader(FILE * file) : file(file) {
+        // read the remaining bytes once and update on each read
+        nbytes_remain = file_remain(file);
+    }
+
+    // helper for remaining bytes in a file
+    static uint64_t file_remain(FILE * file) {
+        const int64_t cur = gguf_ftell(file);
+        if (cur < 0) {
+            return 0;
+        }
+        if (gguf_fseek(file, 0, SEEK_END) != 0) {
+            gguf_fseek(file, cur, SEEK_SET);
+
+            return 0;
+        }
+        const int64_t end = gguf_ftell(file);
+        if (end < 0) {
+            gguf_fseek(file, cur, SEEK_SET);
 
-    gguf_reader(FILE * file) : file(file) {}
+            return 0;
+        }
+        gguf_fseek(file, cur, SEEK_SET);
+        return static_cast(end - cur);
+    }
 
     template 
     bool read(T & dst) const {
-        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+        const size_t size = sizeof(dst);
+        if (nbytes_remain < size) {
+            return false;
+        }
+        const size_t nread = fread(&dst, 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
 
     template 
     bool read(std::vector & dst, const size_t n) const {
+        if (n > GGUF_MAX_ARRAY_ELEMENTS) {
+            return false;
+        }
+        if constexpr (std::is_same::value) {
+            // strings are prefixed with their length, so we need to account for that
+            if (n > SIZE_MAX / sizeof(uint64_t)) {
+                return false;
+            }
+            if (nbytes_remain < n * sizeof(uint64_t)) {
+                return false;
+            }
+        } else {
+            if (n > SIZE_MAX / sizeof(T)) {
+                return false;
+            }
+            if (nbytes_remain < n * sizeof(T)) {
+                return false;
+            }
+        }
         dst.resize(n);
         for (size_t i = 0; i < dst.size(); ++i) {
             if constexpr (std::is_same::value) {
@@ -277,13 +335,33 @@ struct gguf_reader {
         if (!read(size)) {
             return false;
         }
-        dst.resize(size);
-        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+        if (size > GGUF_MAX_STRING_LENGTH) {
+            GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH);
+            return false;
+        }
+        if (size > nbytes_remain) {
+            GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain);
+            return false;
+        }
+        dst.resize(static_cast<size_t>(size));
+        const size_t nread = fread(dst.data(), 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
 
     bool read(void * dst, const size_t size) const {
-        return fread(dst, 1, size, file) == size;
+        if (size > nbytes_remain) {
+            return false;
+        }
+        const size_t nread = fread(dst, 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
+
+private:
+    FILE * file;
+
+    mutable uint64_t nbytes_remain;
 };
 
 struct gguf_context * gguf_init_empty(void) {
@@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
 
             // check that tensor type is within defined range
             if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
-                GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
-                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n",
+                    __func__, info.t.name, info.t.type, GGML_TYPE_COUNT);
                 ok = false;
                 break;
             }
@@ -585,6 +663,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
                 break;
             }
 
+            // check that the size of the tensor in bytes is representable
+            if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) {
+                GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX);
+                ok = false;
+                break;
+            }
+
             // calculate byte offsets given the tensor shape and type
             info.t.nb[0] = type_size;
             info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
@@ -610,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
     GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
 
     // we require the data section to be aligned, so take into account any padding
-    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+    if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {
         GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
         gguf_free(ctx);
         return nullptr;
     }
 
     // store the current file offset - this is where the data section starts
-    ctx->offset = ftell(file);
+    ctx->offset = gguf_ftell(file);
 
     // compute the total size of the data section, taking into account the alignment
     {
@@ -649,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
         //   the ggml_tensor structs to the appropriate locations in the binary blob
 
         // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (n_tensors    )*ggml_tensor_overhead() :
-            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+        size_t mem_size = 0;
+        if (params.no_alloc) {
+            if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            const size_t overhead = n_tensors * ggml_tensor_overhead();
+
+            mem_size = overhead;
+        } else {
+            if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead();
+
+            if (SIZE_MAX - overhead < ctx->size) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            mem_size = overhead + ctx->size;
+        }
 
         struct ggml_init_params pdata = {
             /*mem_size   =*/ mem_size,
@@ -734,7 +844,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     FILE * file = ggml_fopen(fname, "rb");
 
     if (!file) {
-        GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
+        GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno));
         return nullptr;
     }
 
diff --git a/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp b/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp
index 2f395761c5a..4dd66c25f5d 100644
--- a/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp
+++ b/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp
@@ -41,7 +41,7 @@ struct {
     void *pdh_dll_handle;
     // DXGI Functions
     HRESULT (*CreateDXGIFactory1)(REFIID riid, void **ppFactory);
-    // PDH functions  
+    // PDH functions
     PDH_STATUS (*PdhOpenQueryW)(LPCWSTR szDataSource, DWORD_PTR dwUserData, PDH_HQUERY *phQuery);
     PDH_STATUS (*PdhAddCounterW)(PDH_HQUERY hQuery, LPCWSTR szFullCounterPath, DWORD_PTR dwUserData, PDH_HCOUNTER *phCounter);
     PDH_STATUS (*PdhCollectQueryData)(PDH_HQUERY hQuery);
@@ -96,7 +96,7 @@ static std::vector<GpuInfo> get_dxgi_gpu_infos() {
         while (pFactory->EnumAdapters1(i, &pAdapter) != DXGI_ERROR_NOT_FOUND) {
             DXGI_ADAPTER_DESC1 desc;
             pAdapter->GetDesc1(&desc);
-            
+
             // Get all the GPU adapter info
             GpuInfo info;
             fetch_dxgi_adapter_desc1(desc, &info);
@@ -197,7 +197,7 @@ extern "C" {
         dll_functions.PdhCollectQueryData = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCollectQueryData");
         dll_functions.PdhGetFormattedCounterValue = (PDH_STATUS (*)(PDH_HCOUNTER hCounter, DWORD dwFormat, LPDWORD lpdwType, PPDH_FMT_COUNTERVALUE pValue)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhGetFormattedCounterValue");
         dll_functions.PdhCloseQuery = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCloseQuery");
-    
+
         SetErrorMode(old_mode); // set old mode before any return
 
         // Check if any function pointers are NULL (not found)
@@ -209,7 +209,7 @@ extern "C" {
             dll_functions.pdh_dll_handle = NULL;
             return ERROR_PROC_NOT_FOUND;
         }
-    
+
         // No other initializations needed, successfully loaded the libraries and functions!
         return ERROR_SUCCESS;
     }
@@ -294,4 +294,4 @@ extern "C" {
 
 } // extern "C"
 
-#endif // #ifdef _WIN32
\ No newline at end of file
+#endif // #ifdef _WIN32
diff --git a/ml/backend/ggml/ggml/src/mem_hip.cpp b/ml/backend/ggml/ggml/src/mem_hip.cpp
index 23c76580629..734d437a706 100644
--- a/ml/backend/ggml/ggml/src/mem_hip.cpp
+++ b/ml/backend/ggml/ggml/src/mem_hip.cpp
@@ -288,7 +288,7 @@ int ggml_hip_mgmt_init() {
         const char *version = NULL;
         ADLX_RESULT status = adlx.ADLXQueryVersion(&version);
         if (ADLX_SUCCEEDED(status)) {
-            GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version);  
+            GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version);
         }
     }
 
@@ -406,7 +406,7 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
             adlx_gdm_cleanup;
             return status;
         }
-        
+
         adlx_uint totalVRAM = 0;
         status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM);
         if (ADLX_FAILED(status)) {
@@ -555,4 +555,4 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
 
 } // extern "C"
 
-#endif // #ifdef _WIN32
\ No newline at end of file
+#endif // #ifdef _WIN32
diff --git a/ml/backend/ggml/ggml/src/mem_nvml.cpp b/ml/backend/ggml/ggml/src/mem_nvml.cpp
index f473a2a2cbd..f8a4ac7b596 100644
--- a/ml/backend/ggml/ggml/src/mem_nvml.cpp
+++ b/ml/backend/ggml/ggml/src/mem_nvml.cpp
@@ -271,4 +271,4 @@ int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
     return status;
 }
 
-}
\ No newline at end of file
+}
diff --git a/model/models/lfm2/cache.go b/model/models/lfm2/cache.go
index 7c1185e6e03..0df4db97020 100644
--- a/model/models/lfm2/cache.go
+++ b/model/models/lfm2/cache.go
@@ -31,10 +31,6 @@ func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor)
 	return &HybridCache{Recurrent: base}
 }
 
-func (c *HybridCache) slotsTensor() ml.Tensor {
-	return c.SlotsTensor()
-}
-
 func (c *HybridCache) seqTokens() int {
 	return c.SeqTokens()
 }
diff --git a/readline/errors.go b/readline/errors.go
index 1be5213e560..a33596d979d 100644
--- a/readline/errors.go
+++ b/readline/errors.go
@@ -1,11 +1,11 @@
 package readline
 
-import (
-	"errors"
-)
+import "errors"
 
-var ErrInterrupt = errors.New("Interrupt")
-var ErrEditPrompt = errors.New("EditPrompt")
+var (
+	ErrInterrupt  = errors.New("Interrupt")
+	ErrEditPrompt = errors.New("EditPrompt")
+)
 
 type InterruptError struct {
 	Line []rune
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 21e6f3be0e0..4fd27f5a02f 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -178,7 +178,31 @@ function cuda12 {
 }
 
 function cuda13 {
-    cudaCommon("13")
+    # Use Windows-specific preset with reduced architectures to avoid MSVC template compilation issues
+    mkdir -Force -path "${script:DIST_DIR}\" | Out-Null
+    $cudaMajorVer = "13"
+    if ($script:ARCH -ne "arm64") {
+        if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) {
+            foreach ($d in $Script:CUDA_DIRS){
+                if ($d.FullName.Contains("v$cudaMajorVer")) {
+                    if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) {
+                        $cuda=($d.FullName|split-path -parent)
+                        break
+                    }
+                }
+            }
+            write-host "Building CUDA v$cudaMajorVer backend libraries $cuda"
+            $env:CUDAToolkit_ROOT=$cuda
+            & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA 13 Windows" -T cuda="$cuda" --install-prefix "$script:DIST_DIR"
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build build\cuda_v$cudaMajorVer --target ggml-cuda --config Release --parallel $script:JOBS
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        } else {
+            write-host "CUDA v$cudaMajorVer not detected, skipping"
+        }
+    }
 }
 
 function rocm {
@@ -196,6 +220,7 @@ function rocm {
             & cmake -B build\rocm --preset "ROCm 6" -G Ninja `
                 -DCMAKE_C_COMPILER=clang `
                 -DCMAKE_CXX_COMPILER=clang++ `
+                -DCMAKE_HIP_COMPILER="${script:HIP_PATH}\bin\clang++.exe" `
                 -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
                 -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
                 --install-prefix $script:DIST_DIR