@@ -39,6 +39,8 @@ pub struct DataSettings {
     use_opencl_for_attention: bool,
     #[cfg(feature = "opencl")]
     cl: Option<OpenCL>,
+    force_f16: bool,
 }
 
 // OpenCL is safe to send to threads but Rust doesn't know that
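Note: the comment above alludes to the standard escape hatch of unsafely asserting Send on a wrapper type. Below is a minimal, self-contained sketch of that pattern; ClHandle is a hypothetical stand-in, not rllama's actual OpenCL wrapper.

    use std::thread;

    // Hypothetical stand-in for an FFI handle like the one OpenCL wraps.
    struct ClHandle(*mut ());

    // SAFETY: asserted by the author, not checked by the compiler; sound
    // only if the underlying API really tolerates cross-thread use.
    unsafe impl Send for ClHandle {}

    fn main() {
        let handle = ClHandle(std::ptr::null_mut());
        // Raw pointers are !Send, so without the impl above this move
        // across threads would not compile.
        thread::spawn(move || {
            let _p = handle.0;
        })
        .join()
        .unwrap();
    }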
@@ -51,13 +53,14 @@ impl DataSettings {
         DataSettings {
             use_opencl_for_feedforward: false,
             use_opencl_for_attention: false,
+            force_f16: false,
             cl: cl.clone(),
         }
     }
 
     #[cfg(not(feature = "opencl"))]
     pub fn new() -> Self {
-        DataSettings {}
+        DataSettings { force_f16: false }
     }
 
     #[cfg(feature = "opencl")]
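Note: the two new() constructors are not duplicates; the signature itself changes with the opencl feature, which is why the force_f16 initializer must be added to both. A compile-time sketch of the same pattern with illustrative names (OpenCL here is a stand-in unit struct):

    struct OpenCL; // stand-in for the real handle type

    pub struct Settings {
        force_f16: bool,
        #[cfg(feature = "opencl")]
        cl: Option<OpenCL>,
    }

    #[cfg(feature = "opencl")]
    impl Settings {
        pub fn new(cl: Option<OpenCL>) -> Self {
            Settings { force_f16: false, cl }
        }
    }

    #[cfg(not(feature = "opencl"))]
    impl Settings {
        pub fn new() -> Self {
            Settings { force_f16: false }
        }
    }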
@@ -69,6 +72,11 @@ impl DataSettings {
         self.use_opencl_for_attention = true;
         self
     }
+
+    pub fn force_f16(mut self) -> DataSettings {
+        self.force_f16 = true;
+        self
+    }
 }
 
 pub struct TransformerCaches {
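Note: force_f16 follows the same consuming-self builder style as the use_opencl_for_* setters, so call sites can chain it. A hypothetical call site (the construction arguments differ when the opencl feature is enabled):

    #[cfg(not(feature = "opencl"))]
    fn make_settings() -> DataSettings {
        DataSettings::new().force_f16()
    }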
@@ -400,6 +408,12 @@ impl FeedForward {
             FromPiecesDirection::Rows,
         )?;
 
+        if data_settings.force_f16 {
+            w1 = w1.to_f16();
+            w2 = w2.to_f16();
+            w3 = w3.to_f16();
+        }
+
         #[cfg(feature = "opencl")]
         {
             if data_settings.use_opencl_for_feedforward {
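Note: the point of the new block is memory rather than speed: converting each weight matrix to f16 halves its footprint before any optional GPU upload. Back-of-envelope, assuming LLaMA-7B's feed-forward shape (hidden 4096, intermediate 11008) purely as an example:

    fn main() {
        let (rows, cols) = (11008u64, 4096u64);
        let f32_mib = (rows * cols * 4) as f64 / (1024.0 * 1024.0);
        let f16_mib = (rows * cols * 2) as f64 / (1024.0 * 1024.0);
        // One of w1/w2/w3: ~172 MiB in f32 vs ~86 MiB in f16.
        println!("{f32_mib:.0} MiB -> {f16_mib:.0} MiB per matrix");
    }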
@@ -412,12 +426,7 @@ impl FeedForward {
                 w3.to_gpu_inplace(&ds.cl.unwrap()).unwrap();
             }
         }
-        #[cfg(not(feature = "opencl"))]
-        {
-            w1 = w1.to_f32();
-            w2 = w2.to_f32();
-            w3 = w3.to_f32();
-        }
+        // w1, w2, w3 may be f32 or f16 depending on source data.
         Ok(Self {
             w1,
@@ -428,31 +437,40 @@ impl FeedForward {
     }
 
     pub fn forward(&self, x: &mut Tensor) -> Tensor {
+        let original_x_dtype = x.dtype();
+        if x.dtype() != self.w1.dtype() {
+            *x = x.to_same_type(&self.w1);
+        }
         #[cfg(feature = "opencl")]
         let x_was_on_cpu: bool;
         #[cfg(feature = "opencl")]
         {
             x_was_on_cpu = x.is_on_cpu();
             if self.data_settings.use_opencl_for_feedforward {
-                *x = x.to_f16();
                 x.to_gpu_inplace(self.data_settings.cl.as_ref().unwrap())
                     .unwrap();
             }
         }
-        let (w1_out, w3_out) = rayon::join(
+        let (mut w1_out, mut w3_out) = rayon::join(
             || self.w1.matrix_mul_transposed(x),
             || self.w3.matrix_mul_transposed(x),
         );
-        let w1_out = w1_out.silu();
-        let w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
+        // Float16 not supported for some of these ops on CPU.
+        if w1_out.is_on_cpu() && w1_out.dtype() == TensorDType::Float16 {
+            w1_out = w1_out.to_f32();
+            w3_out = w3_out.to_f32();
+        }
+        let w1_out = w1_out.silu();
+        let mut w1w3_out = w1_out.hadamard_product(&w3_out).transpose();
+        if w1w3_out.dtype() != self.w2.dtype() {
+            w1w3_out = w1w3_out.to_same_type(&self.w2);
+        }
         #[cfg(not(feature = "opencl"))]
-        {
-            self
-                .w2
-                .matrix_mul_transposed(&w1w3_out)
-        }
+        if w1w3_out.rows() == 1 {
+            self.w2
+                .matrix_vector_mul_transposed_multithreaded(&w1w3_out)
+                .into_dtype(original_x_dtype)
+        } else {
+            self.w2.matrix_mul_transposed(&w1w3_out)
+        }
         #[cfg(feature = "opencl")]
         {
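Note: the hunk above establishes the convert-compute-convert-back shape that the attention hunks below repeat: remember the caller's dtype, cast the activation to the weights' dtype, spill to f32 for CPU ops that lack f16 kernels, and cast the result back on the way out. A self-contained sketch of just that control flow; the types are stand-ins, not rllama's Tensor API:

    #[derive(Clone, Copy, PartialEq, Debug)]
    enum Dtype { F16, F32 }

    struct T { dtype: Dtype }
    impl T {
        fn cast(self, d: Dtype) -> T { T { dtype: d } }
    }

    fn forward(x: T, weight_dtype: Dtype) -> T {
        let original = x.dtype;
        // Compute in the weights' dtype...
        let mut y = x.cast(weight_dtype);
        // ...spilling to f32 where a CPU op has no f16 kernel...
        if y.dtype == Dtype::F16 {
            y = y.cast(Dtype::F32);
        }
        // ...and hand the caller back whatever dtype it gave us.
        y.cast(original)
    }

    fn main() {
        let out = forward(T { dtype: Dtype::F32 }, Dtype::F16);
        assert_eq!(out.dtype, Dtype::F32);
    }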
@@ -503,6 +521,13 @@ impl Attention {
             FromPiecesDirection::Cols,
         )?;
 
+        if data_settings.force_f16 {
+            wq = wq.to_f16();
+            wk = wk.to_f16();
+            wv = wv.to_f16();
+            wo = wo.to_f16();
+        }
+
         #[cfg(feature = "opencl")]
         {
             if data_settings.use_opencl_for_attention {
@@ -517,13 +542,6 @@ impl Attention {
                 wo.to_gpu_inplace(&ds.cl.unwrap()).unwrap();
             }
         }
-        #[cfg(not(feature = "opencl"))]
-        {
-            wq = wq.to_f32();
-            wk = wk.to_f32();
-            wv = wv.to_f32();
-            wo = wo.to_f32();
-        }
         Ok(Self {
             wq,
@@ -544,13 +562,17 @@ impl Attention {
         mask: &Option<Tensor>,
         attention_cache: &mut AttentionCache,
     ) -> Tensor {
+        let original_x_dtype = x.dtype();
+        if x.dtype() != self.wq.dtype() {
+            *x = x.to_same_type(&self.wq);
+        }
         #[cfg(feature = "opencl")]
         let x_was_on_cpu: bool;
         #[cfg(feature = "opencl")]
         {
             x_was_on_cpu = x.is_on_cpu();
             if self.data_settings.use_opencl_for_attention {
-                *x = x.to_f16();
                 x.to_gpu_inplace(self.data_settings.cl.as_ref().unwrap())
                     .unwrap();
             }
@@ -570,11 +592,11 @@ impl Attention {
         #[cfg(not(feature = "opencl"))]
         let (xq_out, (xk_out, xv_out)) = rayon::join(
-            || x.matrix_mul_transposed(&self.wq),
+            || x.matrix_mul_transposed(&self.wq).to_f32(),
             || {
                 rayon::join(
-                    || x.matrix_mul_transposed(&self.wk),
-                    || x.matrix_mul_transposed(&self.wv),
+                    || x.matrix_mul_transposed(&self.wk).to_f32(),
+                    || x.matrix_mul_transposed(&self.wv).to_f32(),
                 )
             },
         );
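Note: the q/k/v projections are independent, so they run on a nested rayon::join (one join whose second arm joins again, rather than a 3-way scope); the added .to_f32() calls keep the rest of the CPU attention math in f32 even when the weights are f16. A self-contained sketch of the same join shape (requires the rayon crate; work is a made-up stand-in for a projection):

    fn work(i: u64) -> u64 {
        (0..1_000_000u64).fold(i, |acc, v| acc.wrapping_add(v))
    }

    fn main() {
        // Same shape as the patch: the two inner tasks may run in
        // parallel with the outer one.
        let (q, (k, v)) = rayon::join(
            || work(1),
            || rayon::join(|| work(2), || work(3)),
        );
        println!("{q} {k} {v}");
    }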
@@ -666,7 +688,9 @@ impl Attention {
         #[cfg(not(feature = "opencl"))]
         {
             let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows());
-            xq_row.matrix_mul_transposed(&self.wo)
+            xq_row
+                .into_same_type(&self.wo)
+                .matrix_mul_transposed(&self.wo)
         }
         #[cfg(feature = "opencl")]
         {
@@ -689,7 +713,7 @@ impl Attention {
         let output3: Vec<&Tensor> = output2.iter().collect();
         let output2: Tensor = Tensor::concat(&output3);
-        output2
+        output2.into_dtype(original_x_dtype)
     }
 }