Comprender los algoritmos de DeepMind y Strassen | de Stefano Bosisio | octubre de 2022

Una introducción al problema de la multiplicación de matrices, con aplicaciones en Python y JAX

Imagen por Iván Díaz en Unsplash

DeepMind publicó recientemente un artículo interesante que empleó Deep Reinforcement Learning para encontrar nuevos algoritmos de multiplicación de matrices.[1]. Uno de los objetivos de este artículo es reducir la complejidad computacional de la multiplicación de matrices. El artículo ha generado muchos comentarios y preguntas sobre la multiplicación de matrices, como puede ver en El tuit de Demis Hassabis.

La multiplicación de matrices es un área intensa de investigación en matemáticas. [2–10]. Aunque la multiplicación de matrices es un problema simple, la implementación computacional tiene algunos obstáculos que resolver. Si estamos considerando solo matrices cuadradas, la primera idea es calcular el producto como un triple for-loop :

Fig.1: Multiplicación de matrices simple en Python, con una complejidad O(n³)

Un cálculo tan simple tiene una complejidad computacional de O(n³). Esto significa que el tiempo para ejecutar dicho cálculo aumenta como la tercera potencia del tamaño de la matriz. Este es un obstáculo a superar, ya que en AI y ML nos ocupamos de matrices enormes para cada paso de modelo: ¡las redes neuronales son toneladas de multiplicaciones de matrices! Por lo tanto, dado un poder computacional constante, ¡necesitamos más y más tiempo para ejecutar todos nuestros cálculos de IA!

DeepMind ha llevado el problema de la multiplicación de matrices a un paso más concreto. Sin embargo, antes de profundizar en este documento, echemos un vistazo al problema de la multiplicación de matrices y qué algoritmos pueden ayudarnos a reducir la potencia computacional. En particular, veremos el algoritmo de Strassen y luego lo implementaremos en Python y JAX.

Recuerda que para el resto del trabajo el tamaño de las matrices será N>>1000. Todos los algoritmos deben ser aplicados a matrices de bloques.

Fig.2: El producto de la matriz C viene dado por la suma del i-ésimo elemento de las filas de la matriz A y el j-ésimo elemento de las columnas de la matriz B, para devolver el elemento de la matriz C (i,j). [Image by the author]

La matriz producto C, está dada por la suma de las filas y las columnas de las matrices A y B, respectivamente — fig.2.

Fig.3: Esquema visual de la multiplicación de matrices entre la matriz A y B, dando un nuevo producto matriz C o AB

Como vimos en la introducción, la complejidad computacional para el producto estándar de multiplicación de matrices es O(n³). en 1969 Volker Strassenun matemático alemán, destrozó el O(n³) barrera, reduciendo la multiplicación de matrices a 7 multiplicaciones y 18 sumas, llegando a una complejidad de O(n²·⁸⁰⁸)[8]. Si consideramos un conjunto de matrices A, B y C, como en la figura 3, Strassen derivó el siguiente algoritmo:

Fig.4: Algoritmo de Strassen, tal como lo propone en su artículo “La eliminación gaussiana no es óptima”. [Image by the author]

Vale la pena notar algunas cosas sobre este algoritmo:

  • el algoritmo funciona recursivamente en matrices de bloques
  • es fácil probar la complejidad O(n²·⁸⁰⁸). GRAMOdado el tamaño de la matriz norte y las 7 multiplicaciones, se sigue que:
Fig.5: Derive la complejidad del algoritmo dado el tamaño de la matriz y el número de multiplicaciones. [Image by the author]

Todos los pasos de la figura 4 son polinomios, por lo que la multiplicación de matrices puede tratarse como un problema de polinomios.

  • Computacionalmente, el algoritmo de Strassen es inestable con números de precisión flotantes [14]. La inestabilidad numérica se produce por el redondeo de todos los resultados de las submatrices. A medida que avanza el cálculo, el error total se resume en una grave pérdida de precisión.

A partir de estos puntos podemos traducir la multiplicación de matrices a un problema de polinomios. Cada operación en la figura 4 se puede escribir como una combinación lineal, por ejemplo, los pasos I son:

Fig.6: Definición de los pasos I del algoritmo de Strassen como una combinación lineal. [Image by the author]

Aquí α y β son las combinaciones lineales de los elementos de las matrices A y B, mientras que H denota una matriz codificada en caliente para operaciones de suma/resta. Entonces es posible definir los elementos de la matriz C del producto como una combinación lineal y escribir:

Fig.7: El algoritmo de Strassen se puede expresar como una combinación lineal de tres elementos. [Image by the author]

Como puede ver, todo el algoritmo se ha reducido a una combinación lineal. En particular, el lado izquierdo de la ecuación Strassenen la fig.7 se puede denotar por los tamaños de las matrices, m, n, y pags – lo que significa una multiplicación entre metroXpags y norteXpags matrices:

Fig.8: La multiplicación de matrices se puede expresar como un tensor. [Image by the author]

Para Strassen es . La figura 8 describe la multiplicación de matrices como una combinación lineal, o una tensor — es por eso que a veces el algoritmo de Strassen se llama “tensor” . los un, b, y C elementos en la figura 8 forman un tríada. Siguiendo la convención de papel de DeepMind, la tríada se puede expresar como:

Fig.9: Definición del tensor de multiplicación de matrices como una tríada, tal como se define en el artículo de Deep Mind [1]. [Image by the author].

Esta tríada establece el objetivo de encontrar el mejor algoritmo para minimizar la complejidad computacional de la operación de multiplicación de matrices. De hecho, el número mínimo de tríadas define el número mínimo de operaciones para calcular la matriz del producto. Este número mínimo es el rango del tensor R

Fuente del artículo

Deja un comentario